aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--Makefile4
-rw-r--r--pangraph/block.py16
-rw-r--r--pangraph/build.py30
-rw-r--r--pangraph/graph.py175
-rw-r--r--pangraph/sequence.py78
-rw-r--r--pangraph/tree.py13
-rw-r--r--pangraph/utils.py40
-rwxr-xr-xscripts/filter_plasmids.py57
-rwxr-xr-xscripts/parse_log.py113
9 files changed, 480 insertions, 46 deletions
diff --git a/Makefile b/Makefile
index 9681f3b..d1a401d 100644
--- a/Makefile
+++ b/Makefile
@@ -55,8 +55,8 @@ staph:
@echo "cluster staph"; \
pangraph cluster -d data/staph data/staph/assemblies/*.fna.gz
@echo "build staph"; \
- pangraph build -d data/staph -m 500 -b 0 data/staph/guide.json # 1>data/staph/pangraph.json
-
+ pangraph build -d data/staph -m 500 -b 0 -e 2500 -w 1000 data/staph/guide.json
+# 2>staph-e2500-w1000.err 1>staph-e2500-w1000.log
# figures
# figs/figure1.png: $(STAT)
diff --git a/pangraph/block.py b/pangraph/block.py
index 05b7806..92f9107 100644
--- a/pangraph/block.py
+++ b/pangraph/block.py
@@ -26,6 +26,7 @@ class Block(object):
super(Block, self).__init__()
self.id = randomid() if gen else 0
self.seq = None
+ self.pos = {}
self.muts = {}
def __str__(self):
@@ -49,6 +50,10 @@ class Block(object):
def isolates(self):
return dict(Counter([k[0] for k in self.muts]))
+ @property
+ def positions(self):
+ return { tag:(pos, pos+self.len_of(*tag)) for tag, pos in self.pos.items() }
+
# ------------------
# static methods
@@ -56,6 +61,7 @@ class Block(object):
def from_seq(cls, name, seq):
new_blk = cls()
new_blk.seq = as_array(seq)
+ new_blk.pos = {(name, 0): 0}
new_blk.muts = {(name, 0):{}}
return new_blk
@@ -69,6 +75,7 @@ class Block(object):
B = Block()
B.id = d['id']
B.seq = as_array(d['seq'])
+ B.pos = {unpack(k):tuple(v) for k, v in d['pos'].items()}
B.muts = {unpack(k):v for k, v in d['muts'].items()}
return B
@@ -85,6 +92,7 @@ class Block(object):
for s in nblk.muts:
nblk.muts[s].update({p+offset:c for p,c in b.muts[s].items()})
offset += len(b)
+ nblk.pos = { k:v for k,v in blks[0].pos.items() }
return nblk
@@ -138,9 +146,11 @@ class Block(object):
qryblks = [nb for i, nb in enumerate(newblks) if qrys[i] is not None]
if aln['orientation'] == -1:
qryblks = qryblks[::-1]
- refblks = [nb for i, nb in enumerate(newblks) if refs[i] is not None]
- return newblks, qryblks, refblks, isomap
+ refblks = [nb for i, nb in enumerate(newblks) if refs[i] is not None]
+ sharedblks = [nb for i, nb in enumerate(newblks) if refs[i] is not None and qrys[i] is not None]
+
+ return newblks, qryblks, refblks, sharedblks, isomap
# --------------
# methods
@@ -222,6 +232,7 @@ class Block(object):
return {'id' : self.id,
'seq' : "".join(str(n) for n in self.seq),
+ 'pos' : {pack(k) : v for k,v in self.pos.items()},
'muts' : {pack(k) : fix(v) for k, v in self.muts.items()}}
def __len__(self):
@@ -233,6 +244,7 @@ class Block(object):
start = val.start or 0
stop = val.stop or len(self.seq)
b.seq = self.seq[start:stop]
+ b.pos = { iso : start+val.start for iso,start in self.pos.items() }
for s, _ in self.muts.items():
b.muts[s] = {p-start:c for p,c in self.muts[s].items() if p>=start and p<stop}
return b
diff --git a/pangraph/build.py b/pangraph/build.py
index 01595b7..a29b4fa 100644
--- a/pangraph/build.py
+++ b/pangraph/build.py
@@ -37,10 +37,24 @@ def register_args(parser):
type=int,
default=22,
help="energy cost for mutations (used during block merges)")
+ parser.add_argument("-w", "--window",
+ metavar="edge window",
+ type=int,
+ default=1000,
+ help="amount of sequence to align from for end repair")
+ parser.add_argument("-e", "--extend",
+ metavar="edge extend",
+ type=int,
+ default=1000,
+ help="amount of sequence to extend for end repair")
parser.add_argument("-s", "--statistics",
default=False,
action='store_true',
help="boolean flag that toggles whether the graph statistics are computed for intermediate graphs")
+ parser.add_argument("-n", "--num",
+ type=int,
+ default=-1,
+ help="manually sets the tmp directory number. internal use only.")
parser.add_argument("input",
type=str,
default="-",
@@ -63,14 +77,17 @@ def main(args):
root = args.dir.rstrip('/')
tmp = f"{root}/tmp"
- i = 0
- while os.path.isdir(tmp) and i < 32:
- i += 1
- tmp = f"{root}/tmp{i:03d}"
+ if args.num == -1:
+ i = 0
+ while os.path.isdir(tmp) and i < 64:
+ i += 1
+ tmp = f"{root}/tmp{i:03d}"
+ else:
+ tmp = f"{root}/tmp{args.num:03d}"
mkdir(tmp)
log("aligning")
- T.align(tmp, args.len, args.mu, args.beta, args.extensive, args.statistics)
+ T.align(tmp, args.len, args.mu, args.beta, args.extensive, args.window, args.extend, args.statistics)
# TODO: when debugging phase is done, remove tmp directory
graphs = T.collect()
@@ -82,6 +99,7 @@ def main(args):
with open(f"{root}/graph_{i:03d}.fa", 'w') as fd:
g.write_fasta(fd)
- T.write_json(sys.stdout, no_seqs=True)
+ # NOTE: uncomment when done debugging
+ # T.write_json(sys.stdout, no_seqs=True)
return 0
diff --git a/pangraph/graph.py b/pangraph/graph.py
index bc1e9d9..9f00587 100644
--- a/pangraph/graph.py
+++ b/pangraph/graph.py
@@ -1,15 +1,21 @@
-import os, sys
+import io, os, sys
import json
import numpy as np
import pprint
+import subprocess
+import tempfile
+from io import StringIO
from glob import glob
from collections import defaultdict, Counter
+from itertools import chain
-from Bio import SeqIO, Phylo
+from Bio import AlignIO, SeqIO, Phylo
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
+from scipy.stats import entropy
+
from . import suffix
from .block import Block
from .sequence import Node, Path
@@ -18,10 +24,23 @@ from .utils import Strand, as_string, parse_paf, panic, as_record, new_strand
# ------------------------------------------------------------------------
# globals
-EXTEND = 2500
+# WINDOW = 1000
+# EXTEND = 2500
pp = pprint.PrettyPrinter(indent=4)
# ------------------------------------------------------------------------
+# utility
+
+def alignment_entropy(rdr):
+ try:
+ aln = np.array([list(rec) for rec in AlignIO.read(rdr, 'fasta')], np.character).view(np.uint8)
+ S = sum(entropy(np.bincount(aln[:,i])/aln.shape[0]) for i in range(aln.shape[1]))
+ return S/aln.shape[1]
+ except Exception as msg:
+ print(f"ERROR: {msg}")
+ return None
+
+# ------------------------------------------------------------------------
# Junction class
# simple struct
@@ -136,7 +155,7 @@ class Graph(object):
graphs, names = [], []
for name, path in G.seqs.items():
blks = set([b.id for b in path.blocks()])
- gi = [ i for i, g in enumerate(graphs) if overlaps(blks, g)]
+ gi = [i for i, g in enumerate(graphs) if overlaps(blks, g)]
if len(gi) == 0:
graphs.append(blks)
names.append(set([name]))
@@ -159,7 +178,7 @@ class Graph(object):
# ---------------
# methods
- def union(self, qpath, rpath, out, cutoff=0, alpha=10, beta=2, extensive=False):
+ def union(self, qpath, rpath, out, cutoff=0, alpha=10, beta=2, extensive=False, edge_window=1000, edge_extend=2500):
from seqanpy import align_global as align
# ----------------------------------
@@ -319,19 +338,11 @@ class Graph(object):
or not accepted(hit):
continue
- merged = True
- new_blks = self.merge(proc(hit))
+ merged = True
+ self.merge(proc(hit), edge_window, edge_extend)
merged_blks.add(hit['ref']['name'])
merged_blks.add(hit['qry']['name'])
- # for blk in new_blks:
- # for iso in blk.isolates:
- # path = self.seqs[iso]
- # x, n = path.position_of(blk)
- # lb, ub = max(0, x-EXTEND), min(x+blk.len_of(iso, n)+EXTEND, len(path))
- # subpath = path[lb:ub]
- # print(subpath, file=sys.stderr)
- # breakpoint("stop")
self.remove_transitives()
for path in self.seqs.values():
@@ -407,9 +418,6 @@ class Graph(object):
# TODO: check that isos is constant along the chain
for iso in self.blks[c[0][0]].isolates.keys():
self.seqs[iso].merge(c[0], c[-1], new_blk)
- # for n in self.seqs[iso].nodes:
- # if n.blk.id in [e[0] for e in c]:
- # breakpoint("bad deletion")
self.blks[new_blk.id] = new_blk
for b, _ in c:
@@ -421,7 +429,7 @@ class Graph(object):
blks.update(path.blocks())
self.blks = {b.id:self.blks[b.id] for b in blks}
- def merge(self, hit):
+ def merge(self, hit, window, extend):
old_ref = self.blks[hit['ref']['name']]
old_qry = self.blks[hit['qry']['name']]
@@ -444,7 +452,7 @@ class Graph(object):
"qry_name" : hit["qry"]["name"],
"orientation" : hit["orientation"]}
- merged_blks, new_qrys, new_refs, blk_map = Block.from_aln(aln)
+ merged_blks, new_qrys, new_refs, shared_blks, blk_map = Block.from_aln(aln)
for merged_blk in merged_blks:
self.blks[merged_blk.id] = merged_blk
@@ -474,6 +482,133 @@ class Graph(object):
new_blocks = []
new_blocks.extend(update(old_ref, new_refs, hit['ref'], Strand.Plus))
new_blocks.extend(update(old_qry, new_qrys, hit['qry'], hit['orientation']))
+
+ lblks_set_x, rblks_set_x = set(), set()
+ lblks_set_s, rblks_set_s = set(), set()
+ first = True
+ num_seqs = 0
+ for tag in shared_blks[0].muts.keys():
+ pos = [self.seqs[tag[0]].position_of(b, tag[1]) for b in shared_blks]
+ strand = [self.seqs[tag[0]].orientation_of(b, tag[1]) for b in shared_blks]
+ beg, end = pos[0], pos[-1]
+ if strand[0] == Strand.Plus:
+ lwindow = min(window, shared_blks[0].len_of(*tag))
+ rwindow = min(window, shared_blks[-1].len_of(*tag))
+
+ lblks_x = self.seqs[tag[0]][beg[0]-extend:beg[0]+lwindow]
+ rblks_x = self.seqs[tag[0]][end[1]-rwindow:end[1]+extend]
+
+ # lblks_s = self.seqs[tag[0]][beg[0]:beg[0]+window]
+ # rblks_s = self.seqs[tag[0]][end[1]-window:end[1]]
+ elif strand[0] == Strand.Minus:
+ lwindow = min(window, shared_blks[-1].len_of(*tag))
+ rwindow = min(window, shared_blks[0].len_of(*tag))
+ rblks_x = self.seqs[tag[0]][beg[0]-extend:beg[0]+rwindow]
+ lblks_x = self.seqs[tag[0]][end[1]-lwindow:end[1]+extend]
+
+ # rblks_s = self.seqs[tag[0]][beg[0]:beg[0]+window]
+ # lblks_s = self.seqs[tag[0]][end[1]-window:end[1]]
+ else:
+ raise ValueError("unrecognized strand polarity")
+
+ if first:
+ lblks_set_x = set([b.id for b in lblks_x])
+ rblks_set_x = set([b.id for b in rblks_x])
+
+ # lblks_set_s = set([b.id for b in lblks_s])
+ # rblks_set_s = set([b.id for b in rblks_s])
+
+ lblks_set_s = set([b.id for b in lblks_x])
+ rblks_set_s = set([b.id for b in rblks_x])
+
+ first = False
+ else:
+ lblks_set_x.intersection_update(set([b.id for b in lblks_x]))
+ rblks_set_x.intersection_update(set([b.id for b in rblks_x]))
+
+ lblks_set_s.update(set([b.id for b in lblks_x]))
+ rblks_set_s.update(set([b.id for b in rblks_x]))
+ # lblks_set_s.intersection_update(set([b.id for b in lblks_s]))
+ # rblks_set_s.intersection_update(set([b.id for b in rblks_s]))
+ num_seqs += 1
+
+ def emit(side):
+ if side == 'left':
+ delta = len(lblks_set_s)-len(lblks_set_x)
+ elif side == 'right':
+ delta = len(lblks_set_s)-len(rblks_set_x)
+ else:
+ raise ValueError(f"unrecognized argument '{side}' for side")
+
+ if delta > 0 and num_seqs > 1:
+ print(f">LEN={delta}", end=';')
+ try:
+ fd, path = tempfile.mkstemp()
+ with os.fdopen(fd, 'w') as tmp:
+ for i, tag in enumerate(merged_blks[0].muts.keys()):
+ pos = [self.seqs[tag[0]].position_of(b, tag[1]) for b in shared_blks]
+ strand = [self.seqs[tag[0]].orientation_of(b, tag[1]) for b in shared_blks]
+ beg, end = pos[0], pos[-1]
+
+ if strand[0] == Strand.Plus:
+ if side == 'left':
+ left, right = beg[0]-extend,beg[0]+min(window,shared_blks[0].len_of(*tag))
+ elif side == 'right':
+ left, right = end[1]-min(window,shared_blks[-1].len_of(*tag)),end[1]+extend
+ else:
+ raise ValueError(f"unrecognized argument '{side}' for side")
+
+ elif strand[0] == Strand.Minus:
+ if side == 'left':
+ left, right = end[1]-min(window,shared_blks[-1].len_of(*tag)),end[1]+extend
+ elif side == 'right':
+ left, right = beg[0]-extend,beg[0]+min(window, shared_blks[0].len_of(*tag))
+ else:
+ raise ValueError(f"unrecognized argument '{side}' for side")
+
+ iso_blks = self.seqs[tag[0]][left:right]
+ # print("POSITIONS", pos)
+ # print("STRAND", strand)
+ # print("LIST", shared_blks)
+ # print("MERGED", merged_blks)
+ # print("INTERSECTION", lblks_set_x if side == 'left' else rblks_set_x)
+ # print("UNION", lblks_set_s if side == 'left' else rblks_set_s)
+ # print("ISO", iso_blks)
+ # breakpoint("stop")
+ tmp.write(f">isolate_{i:04d} {','.join(b.id for b in iso_blks)}\n")
+ s = self.seqs[tag[0]].sequence_range(left,right)
+ if len(s) > extend + window:
+ breakpoint(f"bad sequence slicing: {len(s)}")
+ tmp.write(s + '\n')
+
+ tmp.flush()
+
+ proc = subprocess.Popen(f"mafft --auto {path}",
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=True)
+ # proc[1] = subprocess.Popen(f"fasttree",
+ # stdin =subprocess.PIPE,
+ # stdout=subprocess.PIPE,
+ # stderr=subprocess.PIPE,
+ # shell=True)
+ out, err = proc.communicate()
+ # out[1], err[1] = proc[1].communicate(input=out[0])
+ # tree = Phylo.read(io.StringIO(out[1].decode('utf-8')), format='newick')
+ print(f"ALIGNMENT={out}", end=";")
+ rdr = StringIO(out.decode('utf-8'))
+ print(f"SCORE={alignment_entropy(rdr)}", end=";")
+ rdr.close()
+ # print(f"SCORE={tree.total_branch_length()/(2*num_seqs)}", end=";")
+ print("\n", end="")
+ finally:
+ os.remove(path)
+ else:
+ print(f">NO MATCH")
+
+ emit('left')
+ emit('right')
+
self.prune_blks()
return [b[0] for b in new_blocks]
diff --git a/pangraph/sequence.py b/pangraph/sequence.py
index b2b5ce7..db2b3e6 100644
--- a/pangraph/sequence.py
+++ b/pangraph/sequence.py
@@ -88,12 +88,6 @@ class Path(object):
return seq
- def position_of(self, blk):
- for i, n in enumerate(self.nodes):
- if n.blk == blk:
- return i, n.num
- raise ValueError("block not found in path")
-
def rm_nil_blks(self):
good, popped = [], set()
for i, n in enumerate(self.nodes):
@@ -121,6 +115,8 @@ class Path(object):
try:
i, j = ids.index(start[0]), ids.index(stop[0])
+ if N > 0:
+ breakpoint("HIT")
if self.nodes[i].strand == start[1]:
beg, end, s = i, j, Strand.Plus
else:
@@ -153,17 +149,75 @@ class Path(object):
self.nodes = new
self.position = np.cumsum([0] + [n.length(self.name) for n in self.nodes])
+ def position_of(self, blk, num):
+ index = { n.num:i for i, n in enumerate(self.nodes) if n.blk == blk }
+ if not num in index:
+ return None
+ return (self.position[index[num]], self.position[index[num]+1])
+
+ def orientation_of(self, blk, num):
+ orientation = { n.num:n.strand for i, n in enumerate(self.nodes) if n.blk == blk }
+ if not num in orientation:
+ return None
+ return orientation[num]
+
+ # TODO: pull out common functionality into a helper function
+ # TODO: merge with other sequence function
+ def sequence_range(self, start=None, stop=None):
+ beg = start or 0
+ end = stop or self.position[-1]
+ l, r = "", ""
+ if beg < 0:
+ if len(self.nodes) > 1:
+ l = self.sequence_range(self.position[-1]+beg,self.position[-1])
+ beg = 0
+ if end > self.position[-1]:
+ if len(self.nodes) > 1:
+ r = self.sequence_range(0,end-self.position[-1])
+ end = self.position[-1]
+ if beg > end:
+ beg, end = end, beg
+
+ i = np.searchsorted(self.position, beg, side='right') - 1
+ j = np.searchsorted(self.position, end, side='left')
+ m = ""
+ if i < j:
+ if j >= len(self.position):
+ breakpoint("what?")
+ if i == j - 1:
+ m = self.nodes[i].blk.extract(self.name, self.nodes[i].num)[(beg-self.position[i]):(end-self.position[i])]
+ else:
+ m = self.nodes[i].blk.extract(self.name, self.nodes[i].num)[(beg-self.position[i]):]
+ for n in self.nodes[i+1:j-1]:
+ m += n.blk.extract(self.name, n.num)
+ n = self.nodes[j-1]
+ b = n.blk.extract(self.name, n.num)
+ m += b[0:(end-self.position[j-1])]
+ return l + m + r
+
def __getitem__(self, index):
if isinstance(index, slice):
beg = index.start or 0
end = index.stop or self.position[-1]
-
- i = np.searchsorted(self.position, beg, side='right')
- j = np.searchsorted(self.position, end, side='right') + 1
- assert i < j, "sorted"
- return [n.blk for n in self.nodes[i:j]]
+ l, r = [], []
+ if beg < 0:
+ if len(self.nodes) > 1:
+ l = self[(self.position[-1]+beg):self.position[-1]]
+ beg = 0
+ if end > self.position[-1]:
+ if len(self.nodes) > 1:
+ r = self[0:(end-self.position[-1])]
+ end = self.position[-1]
+ if beg > end:
+ beg, end = end, beg
+
+ i = np.searchsorted(self.position, beg, side='right') - 1
+ j = np.searchsorted(self.position, end, side='left')
+ if i > j:
+ breakpoint(f"not sorted, {beg}-{end}")
+ return l + [n.blk for n in self.nodes[i:j]] + r
elif isinstance(index, int):
- i = np.searchsorted(self.position, index, side='left')
+ i = np.searchsorted(self.position, index, side='right') - 1
return self.nodes[i].blk
else:
raise ValueError(f"type '{type(index)}' not supported as index")
diff --git a/pangraph/tree.py b/pangraph/tree.py
index 605bac3..d36b77d 100644
--- a/pangraph/tree.py
+++ b/pangraph/tree.py
@@ -139,6 +139,11 @@ class Clade(object):
'fapath' : self.fapath,
'graph' : serialize(self.graph) if self.graph is not None else None }
+ def set_level(self, level):
+ for c in self.child:
+ c.set_level(level+1)
+ self.level = level
+
class Tree(object):
# -------------------
# Class constructor
@@ -287,7 +292,8 @@ class Tree(object):
leafs = {n.name: n for n in self.get_leafs()}
self.seqs = {leafs[name]:seq for name,seq in seqs.items()}
- def align(self, tmpdir, min_blk_len, mu, beta, extensive, log_stats=False, verbose=False):
+ def align(self, tmpdir, min_blk_len, mu, beta, extensive, edge_window, edge_extend, log_stats=False, verbose=False):
+ self.root.set_level(0) # NOTE: for debug logging
stats = {}
# ---------------------------------------------
# internal functions
@@ -345,7 +351,7 @@ class Tree(object):
graph1, fapath1 = node1.graph, node1.fapath
graph2, fapath2 = node2.graph, node2.fapath
graph = Graph.fuse(graph1, graph2)
- graph, _ = graph.union(fapath1, fapath2, f"{tmpdir}/{n.name}", min_blk_len, mu, beta, extensive)
+ graph, _ = graph.union(fapath1, fapath2, f"{tmpdir}/{n.name}", min_blk_len, mu, beta, extensive, edge_window, edge_extend)
else:
graph = node1.graph
@@ -355,7 +361,7 @@ class Tree(object):
itr = f"{tmpdir}/{n.name}_iter_{i}"
with open(f"{itr}.fa", 'w') as fd:
graph.write_fasta(fd)
- graph, contin = graph.union(itr, itr, f"{tmpdir}/{n.name}_iter_{i}", min_blk_len, mu, beta, extensive)
+ graph, contin = graph.union(itr, itr, f"{tmpdir}/{n.name}_iter_{i}", min_blk_len, mu, beta, extensive, edge_window, edge_extend)
if not contin:
return graph
return graph
@@ -377,6 +383,7 @@ class Tree(object):
for n in self.postorder():
if n.is_leaf():
continue
+ print(f"+++LEVEL={n.level}+++")
n.fapath = f"{tmpdir}/{n.name}"
log(f"fusing {n.child[0].name} with {n.child[1].name} @ {n.name}")
n.graph = merge(*n.child)
diff --git a/pangraph/utils.py b/pangraph/utils.py
index e0d85fa..e4b63f3 100644
--- a/pangraph/utils.py
+++ b/pangraph/utils.py
@@ -3,6 +3,7 @@ import csv
import gzip
import numpy as np
+from io import StringIO
from enum import IntEnum
from Bio import SeqIO
@@ -97,7 +98,10 @@ def as_array(x):
return np.array(list(x))
def as_string(x):
- return x.view(f'U{x.size}')[0]
+ try:
+ return x.view(f'U{x.size}')[0]
+ except:
+ return "".join(str(c) for c in x)
def flatten(x):
return np.ndarray.flatten(x[:])
@@ -154,9 +158,43 @@ def getnwk(node, newick, parentdist, leaf_names):
newick = "(%s" % (newick)
return newick
+def as_str(s):
+ if isinstance(s, bytes):
+ return s.decode('utf-8')
+ return s
+
# ------------------------------------------------------------------------
# parsers
+def parse_fasta(fh):
+ class Record:
+ def __init__(self, name=None, meta=None, seq=None):
+ self.seq = seq
+ self.name = name
+ self.meta = meta
+
+ def __str__(self):
+ NL = '\n'
+ nc = 80
+ return f">{self.name} {self.meta}\n{NL.join([self.seq[i:(i+nc)] for i in range(0, len(self.seq), nc)])}"
+
+ def __repr__(self):
+ return str(self)
+
+ header = as_str(fh.readline())
+ while header != "" and header[0] == ">":
+ name = header[1:].split()
+ seq = StringIO()
+ for line in fh:
+ line = as_str(line)
+ if line == "" or line[0] == ">":
+ break
+ seq.write(line[:-1])
+
+ header = as_str(line)
+ yield Record(name=name[0], meta=" ".join(name[1:]), seq=seq.getvalue())
+ seq.close()
+
def parse_paf(fh):
hits = []
for line in fh:
diff --git a/scripts/filter_plasmids.py b/scripts/filter_plasmids.py
new file mode 100755
index 0000000..c309150
--- /dev/null
+++ b/scripts/filter_plasmids.py
@@ -0,0 +1,57 @@
+#!/usr/bin/env python3
+"""
+script to filter plasmids and chromosomes from full genome assemblies
+"""
+
+import os
+import sys
+import gzip
+import builtins
+import argparse
+
+from glob import glob
+
+sys.path.insert(0, os.path.abspath('.')) # gross hack
+from pangraph.utils import parse_fasta, breakpoint
+
+def open(path, *args, **kwargs):
+ if path.endswith('.gz'):
+ return gzip.open(path, *args, **kwargs)
+ else:
+ return builtins.open(path, *args, **kwargs)
+
+def main(dirs, plasmids=True):
+ for d in dirs:
+ in_dir = f"data/{d}/assemblies"
+ if not os.path.exists(in_dir):
+ print(f"{in_dir} doesn't exist. skipping...")
+ continue
+
+ if plasmids:
+ out_dir = f"data/{d}-plasmid/assemblies"
+ else:
+ out_dir = f"data/{d}-chromosome/assemblies"
+
+ if not os.path.exists(out_dir):
+ os.makedirs(out_dir)
+
+ for path in glob(f"{in_dir}/*.f?a*"):
+ with open(path, 'rt') as fd, open(f"{out_dir}/{os.path.basename(path).replace('.gz', '')}", 'w') as wtr:
+ for i, rec in enumerate(parse_fasta(fd)):
+ if i == 0:
+ if not plasmids:
+ wtr.write(str(rec))
+ wtr.write('\n')
+ break
+ continue
+
+ wtr.write(str(rec))
+ wtr.write('\n')
+
+parser = argparse.ArgumentParser(description='seperate plasmids from chromosomes')
+parser.add_argument('directories', metavar='dirs', nargs='+')
+parser.add_argument('--chromosomes', default=False, action='store_true')
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ main(args.directories, plasmids=not args.chromosomes)
diff --git a/scripts/parse_log.py b/scripts/parse_log.py
new file mode 100755
index 0000000..65ef8fd
--- /dev/null
+++ b/scripts/parse_log.py
@@ -0,0 +1,113 @@
+#!/usr/bin/env python3
+"""
+script to process our end repair log files for plotting
+"""
+import ast
+import argparse
+
+from io import StringIO
+from Bio import AlignIO
+from collections import defaultdict
+
+import numpy as np
+import matplotlib.pylab as plt
+
+level_preset = "+++LEVEL="
+level_offset = len(level_preset)
+
+score_preset = "SCORE="
+score_offset = len(score_preset)
+
+msaln_preset = "ALIGNMENT="
+msaln_offset = len(msaln_preset)
+
+def unpack(line):
+ offset = [line.find(";")]
+ offset.append(line.find(";", offset[0]+1))
+ offset.append(line.find(";", offset[1]+1))
+
+ align = StringIO(ast.literal_eval(line[msaln_offset+1+offset[0]:offset[1]]).decode('utf-8'))
+ align = next(AlignIO.parse(align, 'fasta'))
+ score = float(line[offset[1]+1+score_offset:offset[2]])
+
+ return align, score
+
+# TODO: remove hardcoded numbers
+def save_aln_examples(results):
+ stats = results[(500,1000)]
+ nums = [0, 0, 0]
+ for score, aln in stats[1]['hits']:
+ if len(aln) < 10:
+ continue
+
+ if score < 1e-4:
+ with open(f"scratch/1/eg_{nums[0]}.fna", 'w+') as fd:
+ AlignIO.write(aln, fd, "fasta")
+ nums[0] += 1
+ elif score < 1e-2:
+ with open(f"scratch/2/eg_{nums[1]}.fna", 'w+') as fd:
+ AlignIO.write(aln, fd, "fasta")
+ nums[1] += 1
+ else:
+ with open(f"scratch/3/eg_{nums[2]}.fna", 'w+') as fd:
+ AlignIO.write(aln, fd, "fasta")
+ nums[2] += 1
+
+def main(args):
+ results = {}
+ for log_path in args:
+ stats = defaultdict(lambda: {'hits':[], 'miss': 0})
+ level = -1
+ with open(log_path) as log:
+ for line in log:
+ line.rstrip('\n')
+ if line[0] == "+":
+ assert line.startswith(level_preset), "check syntax in log file"
+ level = int(line[level_offset:line.find("+++", level_offset)])
+ continue
+ if line[0] == ">":
+ if line[1:].startswith("NO MATCH"):
+ stats[level]['miss'] += 1
+ continue
+ if line[1:].startswith("LEN="):
+ aln, score = unpack(line)
+ stats[level]['hits'].append((score, aln))
+ continue
+ raise ValueError(f"invalid syntax: {line[1:]}")
+ if len(stats) > 0:
+ path = log_path.replace(".log", "").split("-")
+ e, w = int(path[1][1:]), int(path[2][1:])
+ results[(e,w)] = dict(stats)
+ return results
+
+def plot(results):
+ data = results[(500,1000)]
+ scores = [[] for _ in data.keys()]
+ fracs = np.zeros(len(data.keys()))
+ for i, lvl in enumerate(sorted(data.keys())):
+ scores[i] = [elt[0] for elt in data[lvl]['hits']]
+ fracs[i] = len(scores[i])/(len(scores[i]) + data[lvl]['miss'])
+
+ cmap = plt.cm.get_cmap('plasma')
+ colors = [cmap(x) for x in np.linspace(0,1,len(scores[:-5]))]
+ fig, (ax1, ax2) = plt.subplots(1, 2)
+ ax1.plot(fracs[:-5])
+ ax1.set_xlabel("tree level")
+ ax1.set_ylabel("fraction of good edges")
+ for i, score in enumerate(scores[:-5]):
+ ax2.plot(sorted(np.exp(score)), np.linspace(0,1,len(score)),color=colors[i],label=f"level={i+1}")
+ ax2.set_xlabel("average column entropy")
+ ax2.set_ylabel("CDF")
+ ax2.legend()
+
+parser = argparse.ArgumentParser(description='process our data log files on end repair')
+parser.add_argument('files', type=str, nargs='+')
+
+# -----------------------------------------------------------------------------
+# main point of entry
+
+if __name__ == "__main__":
+ args = parser.parse_args()
+ results = main(args.files)
+ save_aln_examples(results)
+ plot(results)