Files

202 lines
7.9 KiB
Python
Raw Permalink Normal View History

#!/usr/bin/env python3
"""Verify the merged count index against all per-specimen reference sets.
Streams `obikmer dump` once on the merged index, accumulates per-specimen
kmer+count pairs from each column, then compares each against its reference .npz.
Output to stdout: one CSV row per specimen (same columns as verify_count.py):

    species,strain,ref_kmers,idx_kmers,false_neg,false_pos,count_mismatch,
    fn_pct,fp_pct,cm_pct
"""
import argparse
import subprocess
import sys
from pathlib import Path
import numpy as np
# ── encoding ──────────────────────────────────────────────────────────────────
# 2-bit nucleotide codes: the position of a base in _DECODE is its code, and
# _ENCODE maps both the upper- and lower-case letter back to that code.
_DECODE = list('ACGT')
_ENCODE = dict(zip('ACGTacgt', (0, 1, 2, 3, 0, 1, 2, 3)))
def encode_kmer(s: str) -> int:
    """Pack a nucleotide string into an integer, two bits per base.

    The first base ends up in the most significant bit pair. Upper- and
    lower-case A/C/G/T are accepted; any other character raises ``KeyError``
    from the ``_ENCODE`` lookup. An empty string encodes to 0.
    """
    packed = 0
    for base in s:
        # Multiplying by 4 makes room for the next 2-bit code.
        packed = packed * 4 + _ENCODE[base]
    return packed
def decode_kmer(val: int, k: int) -> str:
    """Unpack a 2-bit-per-base integer into an upper-case string of ``k`` bases.

    Bit pairs are read from the most significant of the ``k`` pairs down to
    the least significant, so the first character of the result comes from
    the highest pair. ``k <= 0`` yields an empty string.
    """
    return ''.join(_DECODE[(val >> (2 * pair)) & 3]
                   for pair in range(k - 1, -1, -1))
# ── single-pass dump ──────────────────────────────────────────────────────────
def stream_merged_dump(obikmer_bin: str, index_dir: str,
                       ) -> tuple[list[str], dict[str, tuple[list[int], list[int]]]]:
    """Stream ``<obikmer_bin> dump <index_dir>`` once and split it per column.

    The first output line is treated as a comma-separated header whose first
    field is ignored and whose remaining fields are the specimen labels.
    Every following non-empty line is parsed as ``kmer,count,count,...`` with
    one integer count per specimen column.

    Args:
        obikmer_bin: Executable to run.
        index_dir: Index directory passed to the ``dump`` sub-command.

    Returns:
        specimen_names : column labels in dump order
        per_specimen   : mapping label -> (kmer_ints, counts), holding only
                         the rows whose count in that column is > 0

    Exits the interpreter with status 1 (after a message on stderr) when the
    dump command finishes with a non-zero return code.
    """
    cmd = [obikmer_bin, 'dump', index_dir]
    # The context manager closes the pipe and reaps the child even when a
    # malformed line makes the parsing below raise.
    with subprocess.Popen(cmd, stdout=subprocess.PIPE,
                          stderr=subprocess.DEVNULL, text=True) as proc:
        header_line = proc.stdout.readline().rstrip('\n')
        specimen_names = header_line.split(',')[1:]
        per_specimen: dict[str, tuple[list[int], list[int]]] = {
            name: ([], []) for name in specimen_names}
        # Resolve each column's accumulator pair once rather than doing a
        # dict lookup for every cell of the dump.
        columns = [per_specimen[name] for name in specimen_names]
        for line in proc.stdout:
            line = line.rstrip('\n')
            if not line:
                # Tolerate blank lines instead of failing with IndexError.
                continue
            parts = line.split(',')
            kmer_int = encode_kmer(parts[0])
            for field, (kmers, counts) in enumerate(columns, start=1):
                count = int(parts[field])
                if count > 0:
                    kmers.append(kmer_int)
                    counts.append(count)
    if proc.returncode != 0:
        print(f'ERROR: obikmer dump exited {proc.returncode}', file=sys.stderr)
        sys.exit(1)
    return specimen_names, per_specimen
# ── per-specimen comparison ───────────────────────────────────────────────────
def compare_specimen(name: str,
kmer_list: list[int],
count_list: list[int],
ref_dir: Path,
k: int,
save_fn: Path | None,
save_fp: Path | None,
save_cm: Path | None,
) -> str:
ref_path = ref_dir / f'{name}.npz'
if not ref_path.exists():
print(f' SKIP {name}: no reference at {ref_path}', file=sys.stderr)
return ''
species = name.split('--')[0]
strain = name[len(species) + 2:]
npz = np.load(ref_path)
ref_kmers = npz['kmers'] # sorted uint64
ref_counts = npz['counts'] # uint32
order = np.argsort(np.array(kmer_list, dtype=np.uint64), kind='stable')
idx_kmers = np.array(kmer_list, dtype=np.uint64)[order]
idx_counts = np.array(count_list, dtype=np.uint32)[order]
false_neg = np.setdiff1d(ref_kmers, idx_kmers, assume_unique=True)
false_pos = np.setdiff1d(idx_kmers, ref_kmers, assume_unique=True)
# Count mismatches among shared kmers
pos_in_idx = np.searchsorted(idx_kmers, ref_kmers)
pos_in_idx = np.clip(pos_in_idx, 0, len(idx_kmers) - 1)
shared_mask = idx_kmers[pos_in_idx] == ref_kmers
mismatch_mask = ref_counts[shared_mask] != idx_counts[pos_in_idx[shared_mask]]
cm_kmers = ref_kmers[shared_mask][mismatch_mask]
cm_ref = ref_counts[shared_mask][mismatch_mask]
cm_idx = idx_counts[pos_in_idx[shared_mask]][mismatch_mask]
n_shared = int(shared_mask.sum())
fn_pct = 100.0 * len(false_neg) / len(ref_kmers) if len(ref_kmers) else 0.0
fp_pct = 100.0 * len(false_pos) / len(idx_kmers) if len(idx_kmers) else 0.0
cm_pct = 100.0 * len(cm_kmers) / n_shared if n_shared else 0.0
print(f' {name}: ref={len(ref_kmers):,} idx={len(idx_kmers):,} '
f'fn={len(false_neg):,} ({fn_pct:.4f}%) '
f'fp={len(false_pos):,} ({fp_pct:.4f}%) '
f'cm={len(cm_kmers):,} ({cm_pct:.4f}%)',
file=sys.stderr)
if save_fn and len(false_neg):
fn_file = save_fn / f'{name}_fn.txt'
fn_file.write_text('\n'.join(decode_kmer(int(v), k) for v in false_neg) + '\n')
if save_fp and len(false_pos):
fp_file = save_fp / f'{name}_fp.txt'
fp_file.write_text('\n'.join(decode_kmer(int(v), k) for v in false_pos) + '\n')
if save_cm and len(cm_kmers):
cm_file = save_cm / f'{name}_cm.csv'
lines = ['kmer,ref_count,idx_count']
for v, rc, ic in zip(cm_kmers, cm_ref, cm_idx):
lines.append(f'{decode_kmer(int(v), k)},{rc},{ic}')
cm_file.write_text('\n'.join(lines) + '\n')
return (f'{species},{strain},'
f'{len(ref_kmers)},{len(idx_kmers)},'
f'{len(false_neg)},{len(false_pos)},{len(cm_kmers)},'
f'{fn_pct:.4f},{fp_pct:.4f},{cm_pct:.4f}')
# ── main ─────────────────────────────────────────────────────────────────────
def main() -> None:
    """Command-line entry point.

    With ``--header`` only the CSV header is printed. Otherwise the k-mer
    length is read from the first data row of ``obikmer dump --head 1``, the
    merged dump is streamed once, and one CSV row is printed to stdout for
    every specimen column that has a reference file in REF_DIR.

    Exits with an argparse usage error when INDEX_DIR or REF_DIR is missing
    (and ``--header`` is not given), and with status 1 when the index dump
    contains no k-mer rows.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    ap.add_argument('index', metavar='INDEX_DIR', nargs='?',
                    help='Merged count index directory')
    ap.add_argument('ref_dir', metavar='REF_DIR', nargs='?',
                    help='Directory containing per-specimen .npz reference files')
    ap.add_argument('--obikmer', default='obikmer')
    ap.add_argument('--header', action='store_true',
                    help='Print CSV header and exit')
    ap.add_argument('--save-fn', metavar='DIR',
                    help='Directory for false-negative kmer lists')
    ap.add_argument('--save-fp', metavar='DIR',
                    help='Directory for false-positive kmer lists')
    ap.add_argument('--save-cm', metavar='DIR',
                    help='Directory for count-mismatch CSV files')
    args = ap.parse_args()

    if args.header:
        print('species,strain,ref_kmers,idx_kmers,'
              'false_neg,false_pos,count_mismatch,'
              'fn_pct,fp_pct,cm_pct')
        return

    # The positionals are optional only so that --header can be used alone;
    # report a usage error here instead of crashing later in Path(None).
    if args.index is None or args.ref_dir is None:
        ap.error('INDEX_DIR and REF_DIR are required unless --header is given')

    ref_dir = Path(args.ref_dir)
    save_fn = Path(args.save_fn) if args.save_fn else None
    save_fp = Path(args.save_fp) if args.save_fp else None
    save_cm = Path(args.save_cm) if args.save_cm else None
    for d in (save_fn, save_fp, save_cm):
        if d is not None:
            d.mkdir(parents=True, exist_ok=True)

    # Peek at the first data row to learn k (the length of the k-mer text).
    out1 = subprocess.check_output(
        [args.obikmer, 'dump', '--head', '1', args.index],
        stderr=subprocess.DEVNULL, text=True)
    head_lines = out1.splitlines()
    if len(head_lines) < 2:
        print(f'ERROR: obikmer dump produced no k-mer rows for {args.index}',
              file=sys.stderr)
        sys.exit(1)
    k = len(head_lines[1].split(',')[0])

    print(f'k={k} streaming merged dump: {args.index}', file=sys.stderr)
    specimen_names, per_specimen = stream_merged_dump(args.obikmer, args.index)
    print(f'{len(specimen_names)} specimen columns loaded', file=sys.stderr)
    for name in specimen_names:
        kmers, counts = per_specimen[name]
        row = compare_specimen(name, kmers, counts, ref_dir, k,
                               save_fn, save_fp, save_cm)
        if row:
            print(row)
# Run the CLI only when executed as a script, not when imported.
if __name__ == '__main__':
    main()