182 lines
7.4 KiB
Python
182 lines
7.4 KiB
Python
|
|
#!/usr/bin/env python3
|
||
|
|
"""Compare an obikmer count index against a reference kmer set (presence + counts).
|
||
|
|
|
||
|
|
Loads the reference .npz (sorted uint64 kmers + uint32 counts from build_reference.py),
|
||
|
|
streams `obikmer dump` from a --with-counts index, then reports:
|
||
|
|
- false negatives : kmers in reference absent from the index
|
||
|
|
- false positives : kmers in the index absent from the reference
|
||
|
|
- count mismatches: kmers present in both but with differing counts
|
||
|
|
|
||
|
|
Output to stdout: one CSV row
|
||
|
|
species,strain,ref_kmers,idx_kmers,false_neg,false_pos,count_mismatch,
|
||
|
|
fn_pct,fp_pct,cm_pct
|
||
|
|
"""
|
||
|
|
import argparse
|
||
|
|
import subprocess
|
||
|
|
import sys
|
||
|
|
|
||
|
|
import numpy as np
|
||
|
|
|
||
|
|
|
||
|
|
# ── encoding ──────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
# 2-bit nucleotide codes shared by encode_kmer / decode_kmer.
# _DECODE maps a code (index) back to its uppercase base; _ENCODE maps a base
# of either case to its code: A=0, C=1, G=2, T=3.
_DECODE = list('ACGT')
_ENCODE = {
    **{base: code for code, base in enumerate(_DECODE)},
    **{base.lower(): code for code, base in enumerate(_DECODE)},
}
|
||
|
|
|
||
|
|
|
||
|
|
def encode_kmer(s: str) -> int:
    """Pack a nucleotide string into an integer, 2 bits per base.

    Bases are looked up in ``_ENCODE`` (A=0, C=1, G=2, T=3, either case). The
    first base ends up in the most significant bits. Any character missing
    from ``_ENCODE`` raises ``KeyError``. The empty string packs to 0.
    """
    packed = 0
    for base in s:
        # Multiplying by 4 is a 2-bit left shift; the new code fills the
        # freed low bits.
        packed = packed * 4 + _ENCODE[base]
    return packed
|
||
|
|
|
||
|
|
|
||
|
|
def decode_kmer(val: int, k: int) -> str:
    """Unpack the low ``2 * k`` bits of *val* into a k-letter uppercase string.

    This is the inverse of ``encode_kmer`` for strings of length k. The most
    significant 2-bit group becomes the first base, and any bits above
    ``2 * k`` are ignored. Returns ``''`` when k <= 0.
    """
    # Walk the 2-bit groups from the most significant (first base) down to 0.
    return ''.join(
        _DECODE[(val >> shift) & 3]
        for shift in range(2 * (k - 1), -1, -2)
    )
|
||
|
|
|
||
|
|
|
||
|
|
# ── dump parsing ──────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def load_index(obikmer_bin: str, index_dir: str) -> tuple[np.ndarray, np.ndarray]:
    """Stream `obikmer dump` and return (kmers_sorted_uint64, counts_uint32).

    The first output line is treated as a CSV header and skipped. Blank lines
    are ignored. Every other line must be ``KMER,COUNT``. Kmers are packed
    with ``encode_kmer``, and both arrays are reordered by ascending kmer
    (stable sort).

    Exits the process with status 1 if a row has no count column (e.g. the
    index was not built with --with-counts) or if ``obikmer dump`` exits with
    a non-zero status.
    """
    cmd = [obikmer_bin, 'dump', index_dir]
    kmers: list[int] = []
    counts: list[int] = []
    # The context manager closes the pipe and reaps the child even if
    # parsing raises (e.g. KeyError from encode_kmer on an unexpected base).
    with subprocess.Popen(cmd, stdout=subprocess.PIPE, stderr=subprocess.DEVNULL,
                          text=True) as proc:
        next(proc.stdout, None)  # skip the CSV header line
        for lineno, line in enumerate(proc.stdout, start=2):
            line = line.rstrip('\n')
            if not line:
                continue
            parts = line.split(',')
            if len(parts) < 2:
                proc.kill()
                print(f'ERROR: line {lineno} of obikmer dump has no count column '
                      '(was the index built with --with-counts?)', file=sys.stderr)
                sys.exit(1)
            kmers.append(encode_kmer(parts[0]))
            counts.append(int(parts[1]))
    # Popen.__exit__ has waited for the child, so returncode is set here.
    if proc.returncode != 0:
        print(f'ERROR: obikmer dump exited {proc.returncode}', file=sys.stderr)
        sys.exit(1)

    kmer_arr = np.array(kmers, dtype=np.uint64)
    count_arr = np.array(counts, dtype=np.uint32)
    del kmers, counts  # drop the Python-int lists before sorting
    order = np.argsort(kmer_arr, kind='stable')
    return kmer_arr[order], count_arr[order]
|
||
|
|
|
||
|
|
|
||
|
|
# ── comparison ────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def compare(ref_kmers: np.ndarray, ref_counts: np.ndarray,
            idx_kmers: np.ndarray, idx_counts: np.ndarray,
            ) -> tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Return (false_neg, false_pos, cm_kmers, cm_ref_counts, cm_idx_counts).

    Both kmer arrays must be sorted ascending and contain no duplicates: the
    set differences use ``assume_unique=True`` and shared kmers are found by
    binary search. Each counts array is aligned element-wise with its kmer
    array. Either kmer array may be empty.

    - false_neg: kmers in ``ref_kmers`` absent from ``idx_kmers``.
    - false_pos: kmers in ``idx_kmers`` absent from ``ref_kmers``.
    - cm_kmers, cm_ref_counts, cm_idx_counts: kmers present in both arrays
      whose counts differ, with the reference and index count for each, in
      ``ref_kmers`` order.
    """
    false_neg = np.setdiff1d(ref_kmers, idx_kmers, assume_unique=True)
    false_pos = np.setdiff1d(idx_kmers, ref_kmers, assume_unique=True)

    # Locate each reference kmer in the sorted index by binary search.
    # searchsorted returns len(idx_kmers) for kmers past the last index entry
    # (and for every kmer when the index is empty). Those positions cannot be
    # dereferenced, so they are excluded before the equality test.
    pos_in_idx = np.searchsorted(idx_kmers, ref_kmers)
    in_bounds = pos_in_idx < len(idx_kmers)
    shared_mask = np.zeros(len(ref_kmers), dtype=bool)
    shared_mask[in_bounds] = idx_kmers[pos_in_idx[in_bounds]] == ref_kmers[in_bounds]

    shared_ref_counts = ref_counts[shared_mask]
    shared_idx_counts = idx_counts[pos_in_idx[shared_mask]]
    mismatch_mask = shared_ref_counts != shared_idx_counts

    cm_kmers = ref_kmers[shared_mask][mismatch_mask]
    cm_ref_counts = shared_ref_counts[mismatch_mask]
    cm_idx_counts = shared_idx_counts[mismatch_mask]

    return false_neg, false_pos, cm_kmers, cm_ref_counts, cm_idx_counts
|
||
|
|
|
||
|
|
|
||
|
|
# ── main ─────────────────────────────────────────────────────────────────────
|
||
|
|
|
||
|
|
def _detect_k(obikmer_bin: str, index_dir: str) -> int:
    """Return the kmer length of *index_dir*, read from its first dumped kmer.

    Exits with status 1 if ``obikmer dump`` fails or yields no kmer rows,
    since k cannot be inferred from an empty dump.
    """
    cmd = [obikmer_bin, 'dump', '--head', '1', index_dir]
    try:
        out = subprocess.check_output(cmd, stderr=subprocess.DEVNULL, text=True)
    except subprocess.CalledProcessError as err:
        print(f'ERROR: obikmer dump exited {err.returncode}', file=sys.stderr)
        sys.exit(1)
    rows = out.splitlines()
    # rows[0] is the CSV header; the kmer is the first field of rows[1].
    if len(rows) < 2 or not rows[1]:
        print(f'ERROR: obikmer dump returned no kmers for {index_dir}; '
              'cannot detect k', file=sys.stderr)
        sys.exit(1)
    return len(rows[1].split(',')[0])


def _save_kmers(path: str, kmers: np.ndarray, k: int) -> None:
    """Write each packed kmer in *kmers* to *path* as a k-letter string per line."""
    with open(path, 'w') as fh:
        fh.writelines(decode_kmer(int(v), k) + '\n' for v in kmers)


def main() -> None:
    """Compare an obikmer index against a reference .npz and report the result.

    Totals and percentages go to stderr, and a single CSV row goes to stdout.
    With ``--header``, only the CSV header is printed. The optional
    ``--save-fn``, ``--save-fp`` and ``--save-cm`` flags write the offending
    kmers to files. A requested file is written even when its set is empty.
    """
    ap = argparse.ArgumentParser(description=__doc__,
                                 formatter_class=argparse.RawDescriptionHelpFormatter)
    # Positionals are optional only so `--header` can run on its own; they
    # are validated explicitly after parsing.
    ap.add_argument('reference', metavar='REF_NPZ', nargs='?',
                    help='Reference .npz file')
    ap.add_argument('index', metavar='INDEX_DIR', nargs='?',
                    help='obikmer index directory (built with --with-counts)')
    ap.add_argument('--obikmer', default='obikmer',
                    help='Path to obikmer binary')
    ap.add_argument('--species', default='')
    ap.add_argument('--strain', default='')
    ap.add_argument('--header', action='store_true',
                    help='Print CSV header and exit')
    ap.add_argument('--save-fp', metavar='FILE',
                    help='Save false-positive kmer strings to FILE')
    ap.add_argument('--save-fn', metavar='FILE',
                    help='Save false-negative kmer strings to FILE')
    ap.add_argument('--save-cm', metavar='FILE',
                    help='Save count-mismatch rows (kmer,ref_count,idx_count) to FILE')
    args = ap.parse_args()

    if args.header:
        print('species,strain,ref_kmers,idx_kmers,'
              'false_neg,false_pos,count_mismatch,'
              'fn_pct,fp_pct,cm_pct')
        return

    if args.reference is None or args.index is None:
        ap.error('REF_NPZ and INDEX_DIR are required unless --header is given')

    # Detect k first, so a broken or empty index fails before the reference loads.
    k = _detect_k(args.obikmer, args.index)

    # Load reference. NpzFile item access reads each array fully into memory,
    # so the archive can be closed once both arrays are extracted.
    print(f'Loading reference: {args.reference}', file=sys.stderr)
    with np.load(args.reference) as npz:
        ref_kmers = npz['kmers']    # sorted uint64
        ref_counts = npz['counts']  # uint32

    # Load index
    print(f'Streaming dump (k={k}): {args.index}', file=sys.stderr)
    idx_kmers, idx_counts = load_index(args.obikmer, args.index)

    print(f'k={k} ref={len(ref_kmers):,} idx={len(idx_kmers):,}', file=sys.stderr)

    false_neg, false_pos, cm_kmers, cm_ref, cm_idx = compare(
        ref_kmers, ref_counts, idx_kmers, idx_counts)

    # Count-mismatch percentage is relative to kmers present in both sets.
    n_shared = len(ref_kmers) - len(false_neg)
    fn_pct = 100.0 * len(false_neg) / len(ref_kmers) if len(ref_kmers) else 0.0
    fp_pct = 100.0 * len(false_pos) / len(idx_kmers) if len(idx_kmers) else 0.0
    cm_pct = 100.0 * len(cm_kmers) / n_shared if n_shared else 0.0

    print(f'false negatives : {len(false_neg):,} ({fn_pct:.4f}%)', file=sys.stderr)
    print(f'false positives : {len(false_pos):,} ({fp_pct:.4f}%)', file=sys.stderr)
    print(f'count mismatches: {len(cm_kmers):,} ({cm_pct:.4f}% of shared)',
          file=sys.stderr)

    # Always (re)write requested files, even when empty, so a stale file left
    # by an earlier run is never mistaken for this run's output.
    if args.save_fn:
        _save_kmers(args.save_fn, false_neg, k)
    if args.save_fp:
        _save_kmers(args.save_fp, false_pos, k)
    if args.save_cm:
        with open(args.save_cm, 'w') as fh:
            fh.write('kmer,ref_count,idx_count\n')
            fh.writelines(f'{decode_kmer(int(v), k)},{rc},{ic}\n'
                          for v, rc, ic in zip(cm_kmers, cm_ref, cm_idx))

    print(f'{args.species},{args.strain},'
          f'{len(ref_kmers)},{len(idx_kmers)},'
          f'{len(false_neg)},{len(false_pos)},{len(cm_kmers)},'
          f'{fn_pct:.4f},{fp_pct:.4f},{cm_pct:.4f}')
|
||
|
|
|
||
|
|
|
||
|
|
# Run the command-line entry point only when executed as a script, not when imported.
if __name__ == '__main__':
    main()
|