Files
obikmer/benchmark/compare_all_dist.py
T

183 lines
7.1 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
"""Compare all reference distance matrices against obikmer distance outputs.
Reads from:
reference_dist/ — ground-truth matrices computed by build_reference_dist.py
obikmer_dist/ — matrices produced by `obikmer distance`
Handles label reordering: both matrices are sorted by genome label before
element-wise comparison, so column/row order differences are irrelevant.
Output: stats/dist_comparison/summary.csv
comparison,max_abs,mean_abs,rmse,n_pairs,status
"""
import csv
import sys
from pathlib import Path
import numpy as np
# ── CSV loading ───────────────────────────────────────────────────────────────
def load_matrix(path: Path) -> tuple[list[str], np.ndarray]:
    """Load a distance-matrix CSV and return it in label-sorted order.

    The first row is the header: a leading corner cell (e.g. ``genome``)
    followed by one column label per genome.  Each following row is a row
    label followed by that row's values.  Rows and columns are both
    re-indexed by the sorted row labels, so the on-disk ordering of either
    axis does not matter.  Blank lines are ignored.

    Args:
        path: CSV file to read.

    Returns:
        ``(labels, matrix)`` where ``labels`` is the sorted list of row
        labels and ``matrix`` is an ``(n, n)`` float64 array in which
        ``matrix[i, j]`` is the value at row ``labels[i]``, column
        ``labels[j]``.

    Raises:
        ValueError: the file is empty, or a cell is not a valid float.
        KeyError: a row label has no column of the same name in the header.
        IndexError: a row has fewer cells than the header requires.
    """
    # newline='' lets the csv module do its own newline handling, as the
    # csv documentation recommends for file objects passed to csv.reader.
    with path.open(newline='') as fh:
        reader = csv.reader(fh)
        header_row = next(reader, None)
        if header_row is None:
            raise ValueError(f'{path}: empty CSV file')
        header = header_row[1:]  # skip 'genome' column
        raw: dict[str, list[float]] = {
            row[0]: [float(x) for x in row[1:]]
            for row in reader
            if row  # csv.reader yields [] for blank lines; skip them
        }

    label_to_col = {h: i for i, h in enumerate(header)}
    labels = sorted(raw)
    # Column positions in sorted-label order, computed once for all rows.
    cols = [label_to_col[lab] for lab in labels]

    n = len(labels)
    mat = np.zeros((n, n), dtype=np.float64)
    for i, lab in enumerate(labels):
        values = raw[lab]
        mat[i, :] = [values[c] for c in cols]
    return labels, mat
# ── comparison ────────────────────────────────────────────────────────────────
def compare(label: str,
            ref_path: Path,
            obi_path: Path,
            tol: float = 1e-4) -> dict:
    """Compare one reference matrix against one obikmer matrix.

    Both CSVs are loaded with ``load_matrix`` (which sorts rows and columns
    by label), then the absolute element-wise differences of the
    off-diagonal cells are summarised.  A one-line report is printed to
    stderr for every comparison that reaches the numeric stage, and for
    label mismatches.

    Args:
        label: Name of the comparison, copied to the ``comparison`` field.
        ref_path: Path of the reference matrix CSV.
        obi_path: Path of the obikmer matrix CSV.
        tol: Largest absolute difference still reported as ``PASS``.

    Returns:
        A dict with keys ``comparison``, ``max_abs``, ``mean_abs``, ``rmse``,
        ``n_pairs`` and ``status``.  ``status`` is one of:

        * ``PASS`` / ``FAIL`` - metrics are filled in (as strings);
        * ``REF_MISSING`` / ``OBI_MISSING`` - an input file does not exist;
        * ``LABEL_MISMATCH`` - the two matrices have different label sets;
        * ``NO_PAIRS`` - fewer than two genomes, so there is nothing
          off-diagonal to compare.

        For every status other than ``PASS``/``FAIL`` the metric fields are
        empty strings.
    """
    def skipped(status: str, n_pairs: str = '') -> dict:
        # Row shape shared by every outcome that produces no metrics.
        return {'comparison': label, 'status': status,
                'max_abs': '', 'mean_abs': '', 'rmse': '', 'n_pairs': n_pairs}

    if not ref_path.exists():
        return skipped('REF_MISSING')
    if not obi_path.exists():
        return skipped('OBI_MISSING')

    ref_labels, ref_mat = load_matrix(ref_path)
    obi_labels, obi_mat = load_matrix(obi_path)

    if ref_labels != obi_labels:
        only_ref = sorted(set(ref_labels) - set(obi_labels))
        only_obi = sorted(set(obi_labels) - set(ref_labels))
        print(f' [{label}] label mismatch — '
              f'only_ref={only_ref} only_obi={only_obi}', file=sys.stderr)
        return skipped('LABEL_MISMATCH')

    n = len(ref_labels)
    # Off-diagonal mask
    mask = ~np.eye(n, dtype=bool)
    diff = np.abs(ref_mat[mask] - obi_mat[mask])
    n_pairs = diff.size

    if n_pairs == 0:
        # With 0 or 1 genomes the mask selects nothing, and max()/mean() on
        # a zero-size array would raise or return NaN; report it explicitly.
        print(f' [{label}] n=0 no off-diagonal pairs to compare',
              file=sys.stderr)
        return skipped('NO_PAIRS', n_pairs='0')

    max_abs = float(diff.max())
    mean_abs = float(diff.mean())
    rmse = float(np.sqrt((diff ** 2).mean()))
    # A NaN max_abs compares False here and is therefore reported as FAIL.
    status = 'PASS' if max_abs <= tol else 'FAIL'

    print(f' [{label}] n={n_pairs} '
          f'max={max_abs:.3e} mean={mean_abs:.3e} rmse={rmse:.3e} {status}',
          file=sys.stderr)

    return {
        'comparison': label,
        'max_abs': f'{max_abs:.6e}',
        'mean_abs': f'{mean_abs:.6e}',
        'rmse': f'{rmse:.6e}',
        'n_pairs': str(n_pairs),
        'status': status,
    }
# ── comparison table ──────────────────────────────────────────────────────────
# COMPARISONS holds (label, ref_csv, obikmer_csv) triples.
# Each spec below is (index, metric, ref_stem) and expands to
#   label       = '<index>/<metric>'
#   ref_csv     = 'reference_dist/<ref_stem>.csv'
#   obikmer_csv = 'obikmer_dist/<index>/<metric>.csv'
# The reference jaccard/shared is presence-based, which should match both
# presence/jaccard and count/jaccard (threshold=1).
COMPARISONS = [
    (f'{index}/{metric}',
     f'reference_dist/{ref_stem}.csv',
     f'obikmer_dist/{index}/{metric}.csv')
    for index, metric, ref_stem in (
        # ── presence index ────────────────────────────────────────────────
        ('presence', 'jaccard_dist', 'jaccard_dist'),
        ('presence', 'jaccard_shared', 'shared_kmers'),
        ('presence', 'hamming_dist', 'hamming_dist'),
        # ── count index (jaccard cross-check) ─────────────────────────────
        ('count', 'jaccard_dist', 'jaccard_dist'),
        ('count', 'jaccard_shared', 'shared_kmers'),
        # ── count index (count-based metrics) ─────────────────────────────
        ('count', 'bray_curtis_dist', 'bray_curtis_dist'),
        ('count', 'relfreq_bray_curtis_dist', 'relfreq_bray_curtis_dist'),
        ('count', 'euclidean_dist', 'euclidean_dist'),
        ('count', 'relfreq_euclidean_dist', 'relfreq_euclidean_dist'),
        ('count', 'hellinger_dist', 'hellinger_dist'),
        ('count', 'hellinger_euclidean_dist', 'hellinger_euclidean_dist'),
    )
]
# ── main ─────────────────────────────────────────────────────────────────────
def main() -> None:
    """Run every entry of COMPARISONS and write the summary CSV.

    Command-line options:
        --tol  maximum absolute difference still counted as PASS
        --out  path of the summary CSV (parent directories are created)

    Progress and the final PASS/FAIL/SKIP tally go to stderr.  The process
    exits with status 1 when at least one comparison has status FAIL.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        '--tol', type=float, default=1e-4,
        help='Max abs diff threshold for PASS/FAIL (default 1e-4)',
    )
    parser.add_argument(
        '--out', default='stats/dist_comparison/summary.csv',
        help='Output summary CSV path',
    )
    opts = parser.parse_args()

    summary_path = Path(opts.out)
    summary_path.parent.mkdir(parents=True, exist_ok=True)

    print(f'Comparing {len(COMPARISONS)} matrix pairs…', file=sys.stderr)
    results = [
        compare(name, Path(ref_csv), Path(obi_csv), tol=opts.tol)
        for name, ref_csv, obi_csv in COMPARISONS
    ]

    columns = ['comparison', 'max_abs', 'mean_abs', 'rmse', 'n_pairs', 'status']
    with summary_path.open('w', newline='') as sink:
        writer = csv.DictWriter(sink, fieldnames=columns)
        writer.writeheader()
        writer.writerows(results)
    print(f'\n{summary_path}', file=sys.stderr)

    # Anything that is neither PASS nor FAIL is reported as SKIP.
    statuses = [result.get('status') for result in results]
    passed = statuses.count('PASS')
    failed = statuses.count('FAIL')
    skipped = len(results) - passed - failed
    print(f'Summary: {passed} PASS {failed} FAIL {skipped} SKIP',
          file=sys.stderr)

    if failed:
        sys.exit(1)
# Run the comparisons only when executed as a script, not when imported.
if __name__ == '__main__':
    main()