183 lines
7.1 KiB
Python
183 lines
7.1 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""Compare all reference distance matrices against obikmer distance outputs.

Reads from:
  reference_dist/   — ground-truth matrices computed by build_reference_dist.py
  obikmer_dist/     — matrices produced by `obikmer distance`

Handles label reordering: both matrices are sorted by genome label before
element-wise comparison, so column/row order differences are irrelevant.

Output: stats/dist_comparison/summary.csv
  comparison,max_abs,mean_abs,rmse,n_pairs,status
"""
|
||
|
|
import csv
|
||
|
|
import sys
|
||
|
|
from pathlib import Path
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
|
||
|
|
# ── CSV loading ───────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def load_matrix(path: Path) -> tuple[list[str], np.ndarray]:
    """Load a distance-matrix CSV; return (sorted_labels, matrix_float64).

    The first CSV row is a header: its first cell is ignored and the
    remaining cells name the columns.  Each following non-blank row holds a
    row label followed by one numeric value per column.

    Rows and columns are both reordered to follow the sorted row labels, so
    two matrices over the same label set can be compared element by element
    regardless of the order in which they were written.

    Args:
        path: CSV file to read.

    Returns:
        ``(labels, matrix)`` where ``labels`` is the sorted list of row
        labels and ``matrix`` is an ``(n, n)`` float64 array in which
        ``matrix[i, j]`` is the value found at row ``labels[i]``, column
        ``labels[j]``.

    Raises:
        ValueError: if the file contains no header row, or a cell cannot be
            parsed as a float.
        KeyError: if a row label has no column of the same name in the header.
        IndexError: if a row is too short to hold a required column.
    """
    # newline='' lets the csv module do its own line-ending handling, as its
    # documentation recommends for files passed to csv.reader.
    with path.open(newline='') as fh:
        reader = csv.reader(fh)
        first_row = next(reader, None)
        if first_row is None:
            raise ValueError(f'{path}: empty distance-matrix file')
        header = first_row[1:]  # skip 'genome' column
        # csv.reader yields [] for blank lines (typically a trailing empty
        # line); skip those instead of failing on row[0].
        raw: dict[str, list[float]] = {
            row[0]: [float(x) for x in row[1:]]
            for row in reader
            if row
        }

    label_to_col = {name: idx for idx, name in enumerate(header)}
    labels = sorted(raw)
    n = len(labels)
    # Column position of every label, in sorted-label order.
    cols = [label_to_col[name] for name in labels]
    # reshape() keeps the (0, 0) shape when there are no data rows at all.
    mat = np.array(
        [[raw[row_label][col] for col in cols] for row_label in labels],
        dtype=np.float64,
    ).reshape(n, n)
    return labels, mat
|
||
|
|
|
||
|
|
|
||
|
|
# ── comparison ────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def compare(label: str,
            ref_path: Path,
            obi_path: Path,
            tol: float = 1e-4) -> dict:
    """Compare one reference matrix with one obikmer matrix.

    Both CSV files are loaded with ``load_matrix`` (which sorts rows and
    columns by label), then the absolute differences of all off-diagonal
    entries are summarised.  A progress line is written to stderr.

    Args:
        label: name of the comparison, copied into the result row.
        ref_path: reference matrix CSV.
        obi_path: obikmer matrix CSV.
        tol: largest absolute difference still reported as ``PASS``.

    Returns:
        A dict with the keys ``comparison``, ``max_abs``, ``mean_abs``,
        ``rmse``, ``n_pairs`` and ``status``.  ``status`` is one of:

        * ``REF_MISSING`` / ``OBI_MISSING`` — an input file does not exist;
        * ``LABEL_MISMATCH`` — the two files do not hold the same label set;
        * ``NO_PAIRS`` — fewer than two labels, so nothing can be compared;
        * ``PASS`` / ``FAIL`` — ``max_abs <= tol`` or not.

        The statistics are empty strings for every status except
        ``PASS``/``FAIL``.  ``n_pairs`` counts ordered off-diagonal pairs,
        i.e. both ``(i, j)`` and ``(j, i)``.
    """
    def skipped(status: str, n_pairs: str = '') -> dict:
        # Result row for a comparison that produced no statistics.
        return {'comparison': label, 'status': status,
                'max_abs': '', 'mean_abs': '', 'rmse': '', 'n_pairs': n_pairs}

    if not ref_path.exists():
        return skipped('REF_MISSING')
    if not obi_path.exists():
        return skipped('OBI_MISSING')

    ref_labels, ref_mat = load_matrix(ref_path)
    obi_labels, obi_mat = load_matrix(obi_path)

    if ref_labels != obi_labels:
        only_ref = sorted(set(ref_labels) - set(obi_labels))
        only_obi = sorted(set(obi_labels) - set(ref_labels))
        print(f' [{label}] label mismatch — '
              f'only_ref={only_ref} only_obi={only_obi}', file=sys.stderr)
        return skipped('LABEL_MISMATCH')

    n = len(ref_labels)
    # Off-diagonal mask
    mask = ~np.eye(n, dtype=bool)
    diff = np.abs(ref_mat[mask] - obi_mat[mask])
    n_pairs = diff.size

    if n_pairs == 0:
        # With 0 or 1 labels there is no off-diagonal entry; max()/mean() on
        # the empty array would raise or yield NaN, and reporting PASS would
        # claim agreement that was never checked.
        print(f' [{label}] no off-diagonal pairs to compare (labels={n})',
              file=sys.stderr)
        return skipped('NO_PAIRS', '0')

    max_abs = float(diff.max())
    mean_abs = float(diff.mean())
    rmse = float(np.sqrt((diff ** 2).mean()))
    status = 'PASS' if max_abs <= tol else 'FAIL'

    print(f' [{label}] n={n_pairs} '
          f'max={max_abs:.3e} mean={mean_abs:.3e} rmse={rmse:.3e} {status}',
          file=sys.stderr)
    return {
        'comparison': label,
        'max_abs': f'{max_abs:.6e}',
        'mean_abs': f'{mean_abs:.6e}',
        'rmse': f'{rmse:.6e}',
        'n_pairs': str(n_pairs),
        'status': status,
    }
|
||
|
|
|
||
|
|
|
||
|
|
# ── comparison table ──────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
# Each spec is (obikmer index, obikmer metric, reference file stem) and is
# expanded below into a (label, ref_csv, obikmer_csv) triple.
# The reference jaccard/shared is presence-based, which should match both
# presence/jaccard and count/jaccard (threshold=1).
_COMPARISON_SPECS = (
    # ── presence index ────────────────────────────────────────────────────────
    ('presence', 'jaccard_dist', 'jaccard_dist'),
    ('presence', 'jaccard_shared', 'shared_kmers'),
    ('presence', 'hamming_dist', 'hamming_dist'),
    # ── count index (jaccard cross-check) ─────────────────────────────────────
    ('count', 'jaccard_dist', 'jaccard_dist'),
    ('count', 'jaccard_shared', 'shared_kmers'),
    # ── count index (count-based metrics) ────────────────────────────────────
    ('count', 'bray_curtis_dist', 'bray_curtis_dist'),
    ('count', 'relfreq_bray_curtis_dist', 'relfreq_bray_curtis_dist'),
    ('count', 'euclidean_dist', 'euclidean_dist'),
    ('count', 'relfreq_euclidean_dist', 'relfreq_euclidean_dist'),
    ('count', 'hellinger_dist', 'hellinger_dist'),
    ('count', 'hellinger_euclidean_dist', 'hellinger_euclidean_dist'),
)

COMPARISONS = [
    (
        f'{index}/{metric}',
        f'reference_dist/{ref_stem}.csv',
        f'obikmer_dist/{index}/{metric}.csv',
    )
    for index, metric, ref_stem in _COMPARISON_SPECS
]
|
||
|
|
|
||
|
|
|
||
|
|
# ── main ─────────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def main() -> None:
    """Run every comparison in COMPARISONS and write the summary CSV.

    Command-line options:
        --tol  largest absolute difference still reported as PASS.
        --out  path of the summary CSV; its parent directory is created
               when missing.

    Progress and the final tally go to stderr.  The process exits with
    status 1 when at least one comparison is FAIL; rows with any status
    other than PASS or FAIL are tallied as SKIP and do not change the exit
    status.
    """
    import argparse

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        '--tol',
        type=float,
        default=1e-4,
        help='Max abs diff threshold for PASS/FAIL (default 1e-4)',
    )
    parser.add_argument(
        '--out',
        default='stats/dist_comparison/summary.csv',
        help='Output summary CSV path',
    )
    opts = parser.parse_args()

    summary_path = Path(opts.out)
    summary_path.parent.mkdir(parents=True, exist_ok=True)

    print(f'Comparing {len(COMPARISONS)} matrix pairs…', file=sys.stderr)
    results = [
        compare(name, Path(ref_csv), Path(obi_csv), tol=opts.tol)
        for name, ref_csv, obi_csv in COMPARISONS
    ]

    columns = ['comparison', 'max_abs', 'mean_abs', 'rmse', 'n_pairs', 'status']
    with summary_path.open('w', newline='') as fh:
        writer = csv.DictWriter(fh, fieldnames=columns)
        writer.writeheader()
        for result in results:
            writer.writerow(result)

    print(f'\n→ {summary_path}', file=sys.stderr)

    # Tally outcomes; anything that is neither PASS nor FAIL counts as SKIP.
    statuses = [result.get('status') for result in results]
    passed = statuses.count('PASS')
    failed = statuses.count('FAIL')
    skipped = len(results) - passed - failed
    print(f'Summary: {passed} PASS {failed} FAIL '
          f'{skipped} SKIP', file=sys.stderr)
    if failed > 0:
        sys.exit(1)


# Run the comparisons only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|