diff options
author | Nicholas Noll <nbnoll@eml.cc> | 2020-09-15 13:38:10 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-09-15 13:38:10 -0700 |
commit | 3160f41ee1656d7a4764731e810213a9d52f623b (patch) | |
tree | 32c9b5648aba62c0c12530f7596660d3930eab2e | |
parent | a5f38a5350fbc876ddfb9a8fb041d3937083c4c9 (diff) | |
parent | 4cb96734405fcebb363a72b7b651eec6d9c83cca (diff) |
Fuzzy matching of anchors
-rw-r--r-- | Makefile | 4 | ||||
-rw-r--r-- | pangraph/block.py | 16 | ||||
-rw-r--r-- | pangraph/build.py | 30 | ||||
-rw-r--r-- | pangraph/graph.py | 175 | ||||
-rw-r--r-- | pangraph/sequence.py | 78 | ||||
-rw-r--r-- | pangraph/tree.py | 13 | ||||
-rw-r--r-- | pangraph/utils.py | 40 | ||||
-rwxr-xr-x | scripts/filter_plasmids.py | 57 | ||||
-rwxr-xr-x | scripts/parse_log.py | 113 |
9 files changed, 480 insertions, 46 deletions
@@ -55,8 +55,8 @@ staph: @echo "cluster staph"; \ pangraph cluster -d data/staph data/staph/assemblies/*.fna.gz @echo "build staph"; \ - pangraph build -d data/staph -m 500 -b 0 data/staph/guide.json # 1>data/staph/pangraph.json - + pangraph build -d data/staph -m 500 -b 0 -e 2500 -w 1000 data/staph/guide.json +# 2>staph-e2500-w1000.err 1>staph-e2500-w1000.log # figures # figs/figure1.png: $(STAT) diff --git a/pangraph/block.py b/pangraph/block.py index 05b7806..92f9107 100644 --- a/pangraph/block.py +++ b/pangraph/block.py @@ -26,6 +26,7 @@ class Block(object): super(Block, self).__init__() self.id = randomid() if gen else 0 self.seq = None + self.pos = {} self.muts = {} def __str__(self): @@ -49,6 +50,10 @@ class Block(object): def isolates(self): return dict(Counter([k[0] for k in self.muts])) + @property + def positions(self): + return { tag:(pos, pos+self.len_of(*tag)) for tag, pos in self.pos.items() } + # ------------------ # static methods @@ -56,6 +61,7 @@ class Block(object): def from_seq(cls, name, seq): new_blk = cls() new_blk.seq = as_array(seq) + new_blk.pos = {(name, 0): 0} new_blk.muts = {(name, 0):{}} return new_blk @@ -69,6 +75,7 @@ class Block(object): B = Block() B.id = d['id'] B.seq = as_array(d['seq']) + B.pos = {unpack(k):tuple(v) for k, v in d['pos'].items()} B.muts = {unpack(k):v for k, v in d['muts'].items()} return B @@ -85,6 +92,7 @@ class Block(object): for s in nblk.muts: nblk.muts[s].update({p+offset:c for p,c in b.muts[s].items()}) offset += len(b) + nblk.pos = { k:v for k,v in blks[0].pos.items() } return nblk @@ -138,9 +146,11 @@ class Block(object): qryblks = [nb for i, nb in enumerate(newblks) if qrys[i] is not None] if aln['orientation'] == -1: qryblks = qryblks[::-1] - refblks = [nb for i, nb in enumerate(newblks) if refs[i] is not None] - return newblks, qryblks, refblks, isomap + refblks = [nb for i, nb in enumerate(newblks) if refs[i] is not None] + sharedblks = [nb for i, nb in enumerate(newblks) if refs[i] is not None and qrys[i] is not None] + + return newblks, qryblks, refblks, sharedblks, isomap # -------------- # methods @@ -222,6 +232,7 @@ class Block(object): return {'id' : self.id, 'seq' : "".join(str(n) for n in self.seq), + 'pos' : {pack(k) : v for k,v in self.pos.items()}, 'muts' : {pack(k) : fix(v) for k, v in self.muts.items()}} def __len__(self): @@ -233,6 +244,7 @@ class Block(object): start = val.start or 0 stop = val.stop or len(self.seq) b.seq = self.seq[start:stop] + b.pos = { iso : start+val.start for iso,start in self.pos.items() } for s, _ in self.muts.items(): b.muts[s] = {p-start:c for p,c in self.muts[s].items() if p>=start and p<stop} return b diff --git a/pangraph/build.py b/pangraph/build.py index 01595b7..a29b4fa 100644 --- a/pangraph/build.py +++ b/pangraph/build.py @@ -37,10 +37,24 @@ def register_args(parser): type=int, default=22, help="energy cost for mutations (used during block merges)") + parser.add_argument("-w", "--window", + metavar="edge window", + type=int, + default=1000, + help="amount of sequence to align from for end repair") + parser.add_argument("-e", "--extend", + metavar="edge extend", + type=int, + default=1000, + help="amount of sequence to extend for end repair") parser.add_argument("-s", "--statistics", default=False, action='store_true', help="boolean flag that toggles whether the graph statistics are computed for intermediate graphs") + parser.add_argument("-n", "--num", + type=int, + default=-1, + help="manually sets the tmp directory number. internal use only.") parser.add_argument("input", type=str, default="-", @@ -63,14 +77,17 @@ def main(args): root = args.dir.rstrip('/') tmp = f"{root}/tmp" - i = 0 - while os.path.isdir(tmp) and i < 32: - i += 1 - tmp = f"{root}/tmp{i:03d}" + if args.num == -1: + i = 0 + while os.path.isdir(tmp) and i < 64: + i += 1 + tmp = f"{root}/tmp{i:03d}" + else: + tmp = f"{root}/tmp{args.num:03d}" mkdir(tmp) log("aligning") - T.align(tmp, args.len, args.mu, args.beta, args.extensive, args.statistics) + T.align(tmp, args.len, args.mu, args.beta, args.extensive, args.window, args.extend, args.statistics) # TODO: when debugging phase is done, remove tmp directory graphs = T.collect() @@ -82,6 +99,7 @@ def main(args): with open(f"{root}/graph_{i:03d}.fa", 'w') as fd: g.write_fasta(fd) - T.write_json(sys.stdout, no_seqs=True) + # NOTE: uncomment when done debugging + # T.write_json(sys.stdout, no_seqs=True) return 0 diff --git a/pangraph/graph.py b/pangraph/graph.py index bc1e9d9..9f00587 100644 --- a/pangraph/graph.py +++ b/pangraph/graph.py @@ -1,15 +1,21 @@ -import os, sys +import io, os, sys import json import numpy as np import pprint +import subprocess +import tempfile +from io import StringIO from glob import glob from collections import defaultdict, Counter +from itertools import chain -from Bio import SeqIO, Phylo +from Bio import AlignIO, SeqIO, Phylo from Bio.Seq import Seq from Bio.SeqRecord import SeqRecord +from scipy.stats import entropy + from . import suffix from .block import Block from .sequence import Node, Path @@ -18,10 +24,23 @@ from .utils import Strand, as_string, parse_paf, panic, as_record, new_strand # ------------------------------------------------------------------------ # globals -EXTEND = 2500 +# WINDOW = 1000 +# EXTEND = 2500 pp = pprint.PrettyPrinter(indent=4) # ------------------------------------------------------------------------ +# utility + +def alignment_entropy(rdr): + try: + aln = np.array([list(rec) for rec in AlignIO.read(rdr, 'fasta')], np.character).view(np.uint8) + S = sum(entropy(np.bincount(aln[:,i])/aln.shape[0]) for i in range(aln.shape[1])) + return S/aln.shape[1] + except Exception as msg: + print(f"ERROR: {msg}") + return None + +# ------------------------------------------------------------------------ # Junction class # simple struct @@ -136,7 +155,7 @@ class Graph(object): graphs, names = [], [] for name, path in G.seqs.items(): blks = set([b.id for b in path.blocks()]) - gi = [ i for i, g in enumerate(graphs) if overlaps(blks, g)] + gi = [i for i, g in enumerate(graphs) if overlaps(blks, g)] if len(gi) == 0: graphs.append(blks) names.append(set([name])) @@ -159,7 +178,7 @@ class Graph(object): # --------------- # methods - def union(self, qpath, rpath, out, cutoff=0, alpha=10, beta=2, extensive=False): + def union(self, qpath, rpath, out, cutoff=0, alpha=10, beta=2, extensive=False, edge_window=1000, edge_extend=2500): from seqanpy import align_global as align # ---------------------------------- @@ -319,19 +338,11 @@ class Graph(object): or not accepted(hit): continue - merged = True - new_blks = self.merge(proc(hit)) + merged = True + self.merge(proc(hit), edge_window, edge_extend) merged_blks.add(hit['ref']['name']) merged_blks.add(hit['qry']['name']) - # for blk in new_blks: - # for iso in blk.isolates: - # path = self.seqs[iso] - # x, n = path.position_of(blk) - # lb, ub = max(0, x-EXTEND), min(x+blk.len_of(iso, n)+EXTEND, len(path)) - # subpath = path[lb:ub] - # print(subpath, file=sys.stderr) - # breakpoint("stop") self.remove_transitives() for path in self.seqs.values(): @@ -407,9 +418,6 @@ class Graph(object): # TODO: check that isos is constant along the chain for iso in self.blks[c[0][0]].isolates.keys(): self.seqs[iso].merge(c[0], c[-1], new_blk) - # for n in self.seqs[iso].nodes: - # if n.blk.id in [e[0] for e in c]: - # breakpoint("bad deletion") self.blks[new_blk.id] = new_blk for b, _ in c: @@ -421,7 +429,7 @@ class Graph(object): blks.update(path.blocks()) self.blks = {b.id:self.blks[b.id] for b in blks} - def merge(self, hit): + def merge(self, hit, window, extend): old_ref = self.blks[hit['ref']['name']] old_qry = self.blks[hit['qry']['name']] @@ -444,7 +452,7 @@ class Graph(object): "qry_name" : hit["qry"]["name"], "orientation" : hit["orientation"]} - merged_blks, new_qrys, new_refs, blk_map = Block.from_aln(aln) + merged_blks, new_qrys, new_refs, shared_blks, blk_map = Block.from_aln(aln) for merged_blk in merged_blks: self.blks[merged_blk.id] = merged_blk @@ -474,6 +482,133 @@ class Graph(object): new_blocks = [] new_blocks.extend(update(old_ref, new_refs, hit['ref'], Strand.Plus)) new_blocks.extend(update(old_qry, new_qrys, hit['qry'], hit['orientation'])) + + lblks_set_x, rblks_set_x = set(), set() + lblks_set_s, rblks_set_s = set(), set() + first = True + num_seqs = 0 + for tag in shared_blks[0].muts.keys(): + pos = [self.seqs[tag[0]].position_of(b, tag[1]) for b in shared_blks] + strand = [self.seqs[tag[0]].orientation_of(b, tag[1]) for b in shared_blks] + beg, end = pos[0], pos[-1] + if strand[0] == Strand.Plus: + lwindow = min(window, shared_blks[0].len_of(*tag)) + rwindow = min(window, shared_blks[-1].len_of(*tag)) + + lblks_x = self.seqs[tag[0]][beg[0]-extend:beg[0]+lwindow] + rblks_x = self.seqs[tag[0]][end[1]-rwindow:end[1]+extend] + + # lblks_s = self.seqs[tag[0]][beg[0]:beg[0]+window] + # rblks_s = self.seqs[tag[0]][end[1]-window:end[1]] + elif strand[0] == Strand.Minus: + lwindow = min(window, shared_blks[-1].len_of(*tag)) + rwindow = min(window, shared_blks[0].len_of(*tag)) + rblks_x = self.seqs[tag[0]][beg[0]-extend:beg[0]+rwindow] + lblks_x = self.seqs[tag[0]][end[1]-lwindow:end[1]+extend] + + # rblks_s = self.seqs[tag[0]][beg[0]:beg[0]+window] + # lblks_s = self.seqs[tag[0]][end[1]-window:end[1]] + else: + raise ValueError("unrecognized strand polarity") + + if first: + lblks_set_x = set([b.id for b in lblks_x]) + rblks_set_x = set([b.id for b in rblks_x]) + + # lblks_set_s = set([b.id for b in lblks_s]) + # rblks_set_s = set([b.id for b in rblks_s]) + + lblks_set_s = set([b.id for b in lblks_x]) + rblks_set_s = set([b.id for b in rblks_x]) + + first = False + else: + lblks_set_x.intersection_update(set([b.id for b in lblks_x])) + rblks_set_x.intersection_update(set([b.id for b in rblks_x])) + + lblks_set_s.update(set([b.id for b in lblks_x])) + rblks_set_s.update(set([b.id for b in rblks_x])) + # lblks_set_s.intersection_update(set([b.id for b in lblks_s])) + # rblks_set_s.intersection_update(set([b.id for b in rblks_s])) + num_seqs += 1 + + def emit(side): + if side == 'left': + delta = len(lblks_set_s)-len(lblks_set_x) + elif side == 'right': + delta = len(lblks_set_s)-len(rblks_set_x) + else: + raise ValueError(f"unrecognized argument '{side}' for side") + + if delta > 0 and num_seqs > 1: + print(f">LEN={delta}", end=';') + try: + fd, path = tempfile.mkstemp() + with os.fdopen(fd, 'w') as tmp: + for i, tag in enumerate(merged_blks[0].muts.keys()): + pos = [self.seqs[tag[0]].position_of(b, tag[1]) for b in shared_blks] + strand = [self.seqs[tag[0]].orientation_of(b, tag[1]) for b in shared_blks] + beg, end = pos[0], pos[-1] + + if strand[0] == Strand.Plus: + if side == 'left': + left, right = beg[0]-extend,beg[0]+min(window,shared_blks[0].len_of(*tag)) + elif side == 'right': + left, right = end[1]-min(window,shared_blks[-1].len_of(*tag)),end[1]+extend + else: + raise ValueError(f"unrecognized argument '{side}' for side") + + elif strand[0] == Strand.Minus: + if side == 'left': + left, right = end[1]-min(window,shared_blks[-1].len_of(*tag)),end[1]+extend + elif side == 'right': + left, right = beg[0]-extend,beg[0]+min(window, shared_blks[0].len_of(*tag)) + else: + raise ValueError(f"unrecognized argument '{side}' for side") + + iso_blks = self.seqs[tag[0]][left:right] + # print("POSITIONS", pos) + # print("STRAND", strand) + # print("LIST", shared_blks) + # print("MERGED", merged_blks) + # print("INTERSECTION", lblks_set_x if side == 'left' else rblks_set_x) + # print("UNION", lblks_set_s if side == 'left' else rblks_set_s) + # print("ISO", iso_blks) + # breakpoint("stop") + tmp.write(f">isolate_{i:04d} {','.join(b.id for b in iso_blks)}\n") + s = self.seqs[tag[0]].sequence_range(left,right) + if len(s) > extend + window: + breakpoint(f"bad sequence slicing: {len(s)}") + tmp.write(s + '\n') + + tmp.flush() + + proc = subprocess.Popen(f"mafft --auto {path}", + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + shell=True) + # proc[1] = subprocess.Popen(f"fasttree", + # stdin =subprocess.PIPE, + # stdout=subprocess.PIPE, + # stderr=subprocess.PIPE, + # shell=True) + out, err = proc.communicate() + # out[1], err[1] = proc[1].communicate(input=out[0]) + # tree = Phylo.read(io.StringIO(out[1].decode('utf-8')), format='newick') + print(f"ALIGNMENT={out}", end=";") + rdr = StringIO(out.decode('utf-8')) + print(f"SCORE={alignment_entropy(rdr)}", end=";") + rdr.close() + # print(f"SCORE={tree.total_branch_length()/(2*num_seqs)}", end=";") + print("\n", end="") + finally: + os.remove(path) + else: + print(f">NO MATCH") + + emit('left') + emit('right') + self.prune_blks() return [b[0] for b in new_blocks] diff --git a/pangraph/sequence.py b/pangraph/sequence.py index b2b5ce7..db2b3e6 100644 --- a/pangraph/sequence.py +++ b/pangraph/sequence.py @@ -88,12 +88,6 @@ class Path(object): return seq - def position_of(self, blk): - for i, n in enumerate(self.nodes): - if n.blk == blk: - return i, n.num - raise ValueError("block not found in path") - def rm_nil_blks(self): good, popped = [], set() for i, n in enumerate(self.nodes): @@ -121,6 +115,8 @@ class Path(object): try: i, j = ids.index(start[0]), ids.index(stop[0]) + if N > 0: + breakpoint("HIT") if self.nodes[i].strand == start[1]: beg, end, s = i, j, Strand.Plus else: @@ -153,17 +149,75 @@ class Path(object): self.nodes = new self.position = np.cumsum([0] + [n.length(self.name) for n in self.nodes]) + def position_of(self, blk, num): + index = { n.num:i for i, n in enumerate(self.nodes) if n.blk == blk } + if not num in index: + return None + return (self.position[index[num]], self.position[index[num]+1]) + + def orientation_of(self, blk, num): + orientation = { n.num:n.strand for i, n in enumerate(self.nodes) if n.blk == blk } + if not num in orientation: + return None + return orientation[num] + + # TODO: pull out common functionality into a helper function + # TODO: merge with other sequence function + def sequence_range(self, start=None, stop=None): + beg = start or 0 + end = stop or self.position[-1] + l, r = "", "" + if beg < 0: + if len(self.nodes) > 1: + l = self.sequence_range(self.position[-1]+beg,self.position[-1]) + beg = 0 + if end > self.position[-1]: + if len(self.nodes) > 1: + r = self.sequence_range(0,end-self.position[-1]) + end = self.position[-1] + if beg > end: + beg, end = end, beg + + i = np.searchsorted(self.position, beg, side='right') - 1 + j = np.searchsorted(self.position, end, side='left') + m = "" + if i < j: + if j >= len(self.position): + breakpoint("what?") + if i == j - 1: + m = self.nodes[i].blk.extract(self.name, self.nodes[i].num)[(beg-self.position[i]):(end-self.position[i])] + else: + m = self.nodes[i].blk.extract(self.name, self.nodes[i].num)[(beg-self.position[i]):] + for n in self.nodes[i+1:j-1]: + m += n.blk.extract(self.name, n.num) + n = self.nodes[j-1] + b = n.blk.extract(self.name, n.num) + m += b[0:(end-self.position[j-1])] + return l + m + r + def __getitem__(self, index): if isinstance(index, slice): beg = index.start or 0 end = index.stop or self.position[-1] - - i = np.searchsorted(self.position, beg, side='right') - j = np.searchsorted(self.position, end, side='right') + 1 - assert i < j, "sorted" - return [n.blk for n in self.nodes[i:j]] + l, r = [], [] + if beg < 0: + if len(self.nodes) > 1: + l = self[(self.position[-1]+beg):self.position[-1]] + beg = 0 + if end > self.position[-1]: + if len(self.nodes) > 1: + r = self[0:(end-self.position[-1])] + end = self.position[-1] + if beg > end: + beg, end = end, beg + + i = np.searchsorted(self.position, beg, side='right') - 1 + j = np.searchsorted(self.position, end, side='left') + if i > j: + breakpoint(f"not sorted, {beg}-{end}") + return l + [n.blk for n in self.nodes[i:j]] + r elif isinstance(index, int): - i = np.searchsorted(self.position, index, side='left') + i = np.searchsorted(self.position, index, side='right') - 1 return self.nodes[i].blk else: raise ValueError(f"type '{type(index)}' not supported as index") diff --git a/pangraph/tree.py b/pangraph/tree.py index 605bac3..d36b77d 100644 --- a/pangraph/tree.py +++ b/pangraph/tree.py @@ -139,6 +139,11 @@ class Clade(object): 'fapath' : self.fapath, 'graph' : serialize(self.graph) if self.graph is not None else None } + def set_level(self, level): + for c in self.child: + c.set_level(level+1) + self.level = level + class Tree(object): # ------------------- # Class constructor @@ -287,7 +292,8 @@ class Tree(object): leafs = {n.name: n for n in self.get_leafs()} self.seqs = {leafs[name]:seq for name,seq in seqs.items()} - def align(self, tmpdir, min_blk_len, mu, beta, extensive, log_stats=False, verbose=False): + def align(self, tmpdir, min_blk_len, mu, beta, extensive, edge_window, edge_extend, log_stats=False, verbose=False): + self.root.set_level(0) # NOTE: for debug logging stats = {} # --------------------------------------------- # internal functions @@ -345,7 +351,7 @@ class Tree(object): graph1, fapath1 = node1.graph, node1.fapath graph2, fapath2 = node2.graph, node2.fapath graph = Graph.fuse(graph1, graph2) - graph, _ = graph.union(fapath1, fapath2, f"{tmpdir}/{n.name}", min_blk_len, mu, beta, extensive) + graph, _ = graph.union(fapath1, fapath2, f"{tmpdir}/{n.name}", min_blk_len, mu, beta, extensive, edge_window, edge_extend) else: graph = node1.graph @@ -355,7 +361,7 @@ class Tree(object): itr = f"{tmpdir}/{n.name}_iter_{i}" with open(f"{itr}.fa", 'w') as fd: graph.write_fasta(fd) - graph, contin = graph.union(itr, itr, f"{tmpdir}/{n.name}_iter_{i}", min_blk_len, mu, beta, extensive) + graph, contin = graph.union(itr, itr, f"{tmpdir}/{n.name}_iter_{i}", min_blk_len, mu, beta, extensive, edge_window, edge_extend) if not contin: return graph return graph @@ -377,6 +383,7 @@ class Tree(object): for n in self.postorder(): if n.is_leaf(): continue + print(f"+++LEVEL={n.level}+++") n.fapath = f"{tmpdir}/{n.name}" log(f"fusing {n.child[0].name} with {n.child[1].name} @ {n.name}") n.graph = merge(*n.child) diff --git a/pangraph/utils.py b/pangraph/utils.py index e0d85fa..e4b63f3 100644 --- a/pangraph/utils.py +++ b/pangraph/utils.py @@ -3,6 +3,7 @@ import csv import gzip import numpy as np +from io import StringIO from enum import IntEnum from Bio import SeqIO @@ -97,7 +98,10 @@ def as_array(x): return np.array(list(x)) def as_string(x): - return x.view(f'U{x.size}')[0] + try: + return x.view(f'U{x.size}')[0] + except: + return "".join(str(c) for c in x) def flatten(x): return np.ndarray.flatten(x[:]) @@ -154,9 +158,43 @@ def getnwk(node, newick, parentdist, leaf_names): newick = "(%s" % (newick) return newick +def as_str(s): + if isinstance(s, bytes): + return s.decode('utf-8') + return s + # ------------------------------------------------------------------------ # parsers +def parse_fasta(fh): + class Record: + def __init__(self, name=None, meta=None, seq=None): + self.seq = seq + self.name = name + self.meta = meta + + def __str__(self): + NL = '\n' + nc = 80 + return f">{self.name} {self.meta}\n{NL.join([self.seq[i:(i+nc)] for i in range(0, len(self.seq), nc)])}" + + def __repr__(self): + return str(self) + + header = as_str(fh.readline()) + while header != "" and header[0] == ">": + name = header[1:].split() + seq = StringIO() + for line in fh: + line = as_str(line) + if line == "" or line[0] == ">": + break + seq.write(line[:-1]) + + header = as_str(line) + yield Record(name=name[0], meta=" ".join(name[1:]), seq=seq.getvalue()) + seq.close() + def parse_paf(fh): hits = [] for line in fh: diff --git a/scripts/filter_plasmids.py b/scripts/filter_plasmids.py new file mode 100755 index 0000000..c309150 --- /dev/null +++ b/scripts/filter_plasmids.py @@ -0,0 +1,57 @@ +#!/usr/bin/env python3 +""" +script to filter plasmids and chromosomes from full genome assemblies +""" + +import os +import sys +import gzip +import builtins +import argparse + +from glob import glob + +sys.path.insert(0, os.path.abspath('.')) # gross hack +from pangraph.utils import parse_fasta, breakpoint + +def open(path, *args, **kwargs): + if path.endswith('.gz'): + return gzip.open(path, *args, **kwargs) + else: + return builtins.open(path, *args, **kwargs) + +def main(dirs, plasmids=True): + for d in dirs: + in_dir = f"data/{d}/assemblies" + if not os.path.exists(in_dir): + print(f"{in_dir} doesn't exist. skipping...") + continue + + if plasmids: + out_dir = f"data/{d}-plasmid/assemblies" + else: + out_dir = f"data/{d}-chromosome/assemblies" + + if not os.path.exists(out_dir): + os.makedirs(out_dir) + + for path in glob(f"{in_dir}/*.f?a*"): + with open(path, 'rt') as fd, open(f"{out_dir}/{os.path.basename(path).replace('.gz', '')}", 'w') as wtr: + for i, rec in enumerate(parse_fasta(fd)): + if i == 0: + if not plasmids: + wtr.write(str(rec)) + wtr.write('\n') + break + continue + + wtr.write(str(rec)) + wtr.write('\n') + +parser = argparse.ArgumentParser(description='seperate plasmids from chromosomes') +parser.add_argument('directories', metavar='dirs', nargs='+') +parser.add_argument('--chromosomes', default=False, action='store_true') + +if __name__ == "__main__": + args = parser.parse_args() + main(args.directories, plasmids=not args.chromosomes) diff --git a/scripts/parse_log.py b/scripts/parse_log.py new file mode 100755 index 0000000..65ef8fd --- /dev/null +++ b/scripts/parse_log.py @@ -0,0 +1,113 @@ +#!/usr/bin/env python3 +""" +script to process our end repair log files for plotting +""" +import ast +import argparse + +from io import StringIO +from Bio import AlignIO +from collections import defaultdict + +import numpy as np +import matplotlib.pylab as plt + +level_preset = "+++LEVEL=" +level_offset = len(level_preset) + +score_preset = "SCORE=" +score_offset = len(score_preset) + +msaln_preset = "ALIGNMENT=" +msaln_offset = len(msaln_preset) + +def unpack(line): + offset = [line.find(";")] + offset.append(line.find(";", offset[0]+1)) + offset.append(line.find(";", offset[1]+1)) + + align = StringIO(ast.literal_eval(line[msaln_offset+1+offset[0]:offset[1]]).decode('utf-8')) + align = next(AlignIO.parse(align, 'fasta')) + score = float(line[offset[1]+1+score_offset:offset[2]]) + + return align, score + +# TODO: remove hardcoded numbers +def save_aln_examples(results): + stats = results[(500,1000)] + nums = [0, 0, 0] + for score, aln in stats[1]['hits']: + if len(aln) < 10: + continue + + if score < 1e-4: + with open(f"scratch/1/eg_{nums[0]}.fna", 'w+') as fd: + AlignIO.write(aln, fd, "fasta") + nums[0] += 1 + elif score < 1e-2: + with open(f"scratch/2/eg_{nums[1]}.fna", 'w+') as fd: + AlignIO.write(aln, fd, "fasta") + nums[1] += 1 + else: + with open(f"scratch/3/eg_{nums[2]}.fna", 'w+') as fd: + AlignIO.write(aln, fd, "fasta") + nums[2] += 1 + +def main(args): + results = {} + for log_path in args: + stats = defaultdict(lambda: {'hits':[], 'miss': 0}) + level = -1 + with open(log_path) as log: + for line in log: + line.rstrip('\n') + if line[0] == "+": + assert line.startswith(level_preset), "check syntax in log file" + level = int(line[level_offset:line.find("+++", level_offset)]) + continue + if line[0] == ">": + if line[1:].startswith("NO MATCH"): + stats[level]['miss'] += 1 + continue + if line[1:].startswith("LEN="): + aln, score = unpack(line) + stats[level]['hits'].append((score, aln)) + continue + raise ValueError(f"invalid syntax: {line[1:]}") + if len(stats) > 0: + path = log_path.replace(".log", "").split("-") + e, w = int(path[1][1:]), int(path[2][1:]) + results[(e,w)] = dict(stats) + return results + +def plot(results): + data = results[(500,1000)] + scores = [[] for _ in data.keys()] + fracs = np.zeros(len(data.keys())) + for i, lvl in enumerate(sorted(data.keys())): + scores[i] = [elt[0] for elt in data[lvl]['hits']] + fracs[i] = len(scores[i])/(len(scores[i]) + data[lvl]['miss']) + + cmap = plt.cm.get_cmap('plasma') + colors = [cmap(x) for x in np.linspace(0,1,len(scores[:-5]))] + fig, (ax1, ax2) = plt.subplots(1, 2) + ax1.plot(fracs[:-5]) + ax1.set_xlabel("tree level") + ax1.set_ylabel("fraction of good edges") + for i, score in enumerate(scores[:-5]): + ax2.plot(sorted(np.exp(score)), np.linspace(0,1,len(score)),color=colors[i],label=f"level={i+1}") + ax2.set_xlabel("average column entropy") + ax2.set_ylabel("CDF") + ax2.legend() + +parser = argparse.ArgumentParser(description='process our data log files on end repair') +parser.add_argument('files', type=str, nargs='+') + +# ----------------------------------------------------------------------------- +# main point of entry + +if __name__ == "__main__": + args = parser.parse_args() + results = main(args.files) + save_aln_examples(results) + plot(results) |