469e53b6f5
Introduces scripts to compute and validate pairwise genomic distance matrices across multiple metrics. Updates the Makefile with build and comparison targets, adds .gitignore rules for generated outputs, and includes test CSV matrices and a Newick phylogenetic tree for validating the distance computation pipeline.
183 lines
7.1 KiB
Python
Executable File
183 lines
7.1 KiB
Python
Executable File
#!/usr/bin/env python3
|
|
"""Compare all reference distance matrices against obikmer distance outputs.
|
|
|
|
Reads from:
|
|
reference_dist/ — ground-truth matrices computed by build_reference_dist.py
|
|
obikmer_dist/ — matrices produced by `obikmer distance`
|
|
|
|
Handles label reordering: both matrices are sorted by genome label before
element-wise comparison, so row/column order differences are irrelevant.
Diagonal (self-distance) cells are excluded from the comparison.

Output (default path; override with --out): stats/dist_comparison/summary.csv
    comparison,max_abs,mean_abs,rmse,n_pairs,status
|
|
"""
|
|
import csv
|
|
import sys
|
|
from pathlib import Path
|
|
|
|
import numpy as np
|
|
|
|
|
|
# ── CSV loading ───────────────────────────────────────────────────────────────
|
|
|
|
def load_matrix(path: Path) -> tuple[list[str], np.ndarray]:
    """Load a labelled square distance-matrix CSV.

    Expected layout: a header row whose first cell names the label column
    (its text is ignored) followed by one column label per genome, then one
    row per genome of the form ``label,value,value,...``.  Blank lines are
    skipped.

    Rows and columns are both reordered by sorted label, so two files that
    list the same genomes in different orders produce identical matrices.

    Args:
        path: CSV file to read.

    Returns:
        ``(labels, matrix)`` where ``labels`` is sorted and ``matrix[i, j]``
        is the float64 value at row ``labels[i]``, column ``labels[j]``.

    Raises:
        ValueError: if the file is empty, a row has the wrong number of
            cells, a row label is repeated, a cell is not a number, or the
            row labels are not exactly the column labels.
    """
    # newline='' is what the csv module expects for files it parses.
    with path.open(newline='') as fh:
        reader = csv.reader(fh)
        try:
            header = next(reader)[1:]  # drop the label-column cell
        except StopIteration:
            raise ValueError(f'{path}: empty file') from None

        raw: dict[str, list[float]] = {}
        for row in reader:
            if not row:  # csv.reader yields [] for a blank line
                continue
            name, cells = row[0], row[1:]
            where = f'{path}:{reader.line_num}'
            if len(cells) != len(header):
                raise ValueError(f'{where}: expected {len(header)} values, '
                                 f'got {len(cells)}')
            if name in raw:
                raise ValueError(f'{where}: duplicate row label {name!r}')
            try:
                raw[name] = [float(x) for x in cells]
            except ValueError as err:
                raise ValueError(f'{where}: {err}') from err

    labels = sorted(raw)
    # Row labels are unique, so this also rejects repeated column labels.
    if sorted(header) != labels:
        only_rows = sorted(set(labels) - set(header))
        only_cols = sorted(set(header) - set(labels))
        raise ValueError(f'{path}: row/column labels differ or repeat '
                         f'(only in rows: {only_rows}, '
                         f'only in columns: {only_cols})')

    col_of = {h: i for i, h in enumerate(header)}
    order = np.asarray([col_of[lab] for lab in labels], dtype=np.intp)
    n = len(labels)
    # Stack rows in sorted-label order, then permute columns the same way.
    # reshape keeps the (0, 0) shape for a header-only file.
    rows = np.array([raw[lab] for lab in labels], dtype=np.float64)
    return labels, rows.reshape(n, n)[:, order]
|
|
|
|
|
|
# ── comparison ────────────────────────────────────────────────────────────────
|
|
|
|
def _skipped_row(label: str, status: str, n_pairs: str = '') -> dict:
    """Return a summary row for a comparison that produced no statistics."""
    return {'comparison': label, 'status': status,
            'max_abs': '', 'mean_abs': '', 'rmse': '', 'n_pairs': n_pairs}


def _summarize(label: str, diff: np.ndarray, tol: float) -> dict:
    """Turn a flat array of absolute differences into a summary row.

    Logs one line per comparison to stderr.  An empty ``diff`` (fewer than
    two genomes, hence no off-diagonal cells) yields status ``NO_PAIRS``
    instead of crashing in ``diff.max()``.  NaN differences propagate into
    max/mean/rmse and, because ``nan <= tol`` is false, produce ``FAIL``.
    """
    n_pairs = diff.size
    if n_pairs == 0:
        print(f' [{label}] no off-diagonal pairs to compare', file=sys.stderr)
        return _skipped_row(label, 'NO_PAIRS', n_pairs='0')

    max_abs = float(diff.max())
    mean_abs = float(diff.mean())
    rmse = float(np.sqrt(np.mean(diff ** 2)))
    status = 'PASS' if max_abs <= tol else 'FAIL'

    print(f' [{label}] n={n_pairs} '
          f'max={max_abs:.3e} mean={mean_abs:.3e} rmse={rmse:.3e} {status}',
          file=sys.stderr)
    return {
        'comparison': label,
        'max_abs': f'{max_abs:.6e}',
        'mean_abs': f'{mean_abs:.6e}',
        'rmse': f'{rmse:.6e}',
        'n_pairs': str(n_pairs),
        'status': status,
    }


def compare(label: str,
            ref_path: Path,
            obi_path: Path,
            tol: float = 1e-4) -> dict:
    """Compare a reference matrix with an obikmer matrix, off-diagonal only.

    Args:
        label: name of the comparison, copied into the summary row.
        ref_path: reference distance-matrix CSV.
        obi_path: obikmer distance-matrix CSV.
        tol: largest absolute off-diagonal difference still counted as PASS.

    Returns:
        A summary-row dict with keys comparison, max_abs, mean_abs, rmse,
        n_pairs and status.  Status is REF_MISSING / OBI_MISSING (file
        absent), LABEL_MISMATCH (different genome sets), NO_PAIRS (no
        off-diagonal cells), or PASS / FAIL.
    """
    if not ref_path.exists():
        return _skipped_row(label, 'REF_MISSING')
    if not obi_path.exists():
        return _skipped_row(label, 'OBI_MISSING')

    ref_labels, ref_mat = load_matrix(ref_path)
    obi_labels, obi_mat = load_matrix(obi_path)

    if ref_labels != obi_labels:
        only_ref = sorted(set(ref_labels) - set(obi_labels))
        only_obi = sorted(set(obi_labels) - set(ref_labels))
        print(f' [{label}] label mismatch — '
              f'only_ref={only_ref} only_obi={only_obi}', file=sys.stderr)
        return _skipped_row(label, 'LABEL_MISMATCH')

    # Diagonal (self-distance) cells are excluded from the comparison.
    off_diag = ~np.eye(len(ref_labels), dtype=bool)
    diff = np.abs(ref_mat[off_diag] - obi_mat[off_diag])
    return _summarize(label, diff, tol)
|
|
|
|
|
|
# ── comparison table ──────────────────────────────────────────────────────────

# Each entry is (label, reference_csv, obikmer_csv).  The label is
# "<index>/<metric>" and obikmer's output lives at
# obikmer_dist/<index>/<metric>.csv.
#
# The reference Jaccard / shared-k-mer matrices are presence-based, so they
# are checked against both the presence index and the count index (the
# latter should agree with a count threshold of 1).

# Metric name -> reference file stem, only where the two differ.
_REF_STEM = {'jaccard_shared': 'shared_kmers'}

_PRESENCE_METRICS = (
    'jaccard_dist',
    'jaccard_shared',
    'hamming_dist',
)

_COUNT_METRICS = (
    # jaccard cross-check against the presence-based reference
    'jaccard_dist',
    'jaccard_shared',
    # count-based metrics
    'bray_curtis_dist',
    'relfreq_bray_curtis_dist',
    'euclidean_dist',
    'relfreq_euclidean_dist',
    'hellinger_dist',
    'hellinger_euclidean_dist',
)

COMPARISONS = [
    (f'{index}/{metric}',
     f'reference_dist/{_REF_STEM.get(metric, metric)}.csv',
     f'obikmer_dist/{index}/{metric}.csv')
    for index, metrics in (('presence', _PRESENCE_METRICS),
                           ('count', _COUNT_METRICS))
    for metric in metrics
]
|
|
|
|
|
|
# ── main ─────────────────────────────────────────────────────────────────────
|
|
|
|
def main() -> None:
    """Run every comparison in COMPARISONS and write the summary CSV.

    Per-comparison progress is logged to stderr.  Exits with status 1 when
    any comparison is FAIL or LABEL_MISMATCH.  A label mismatch means the
    obikmer matrix covers a different genome set than the reference, which
    is a validation failure, not a skip.  Missing inputs and matrices
    without off-diagonal pairs are reported as SKIP.
    """
    import argparse

    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument('--tol', type=float, default=1e-4,
                    help='Max abs diff threshold for PASS/FAIL (default 1e-4)')
    ap.add_argument('--out', default='stats/dist_comparison/summary.csv',
                    help='Output summary CSV path')
    args = ap.parse_args()
    # A negative or NaN tolerance would silently turn every comparison
    # into FAIL; `not x >= 0` also catches NaN.
    if not args.tol >= 0:
        ap.error('--tol must be a non-negative number')

    out_path = Path(args.out)
    out_path.parent.mkdir(parents=True, exist_ok=True)

    print(f'Comparing {len(COMPARISONS)} matrix pairs…', file=sys.stderr)
    rows = [compare(label, Path(ref), Path(obi), tol=args.tol)
            for label, ref, obi in COMPARISONS]

    fields = ['comparison', 'max_abs', 'mean_abs', 'rmse', 'n_pairs', 'status']
    with out_path.open('w', newline='') as fh:
        writer = csv.DictWriter(fh, fieldnames=fields)
        writer.writeheader()
        writer.writerows(rows)

    print(f'\n→ {out_path}', file=sys.stderr)

    failing = {'FAIL', 'LABEL_MISMATCH'}
    n_fail = sum(1 for r in rows if r.get('status') in failing)
    n_pass = sum(1 for r in rows if r.get('status') == 'PASS')
    print(f'Summary: {n_pass} PASS {n_fail} FAIL '
          f'{len(rows) - n_pass - n_fail} SKIP', file=sys.stderr)
    if n_fail:
        sys.exit(1)


if __name__ == '__main__':
    main()
|