about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Nicholas Noll <nbnoll@eml.cc> 2020-09-10 13:55:29 -0700
committer Nicholas Noll <nbnoll@eml.cc> 2020-09-10 13:55:29 -0700
commit d1712b5a2bbeb61c0660c57e76ccb0e62230edee (patch)
tree 533583d8c4b8d03e98fa69a0825d3044b8ce3849
parent 78639f76e10810296607df2e9b3f839117185cde (diff)
fix: more consistent handling of left/right extension
-rw-r--r-- Makefile 4
-rw-r--r-- pangraph/graph.py 134
-rwxr-xr-x scripts/parse_log.py 50
3 files changed, 105 insertions, 83 deletions
diff --git a/Makefile b/Makefile
index b0b9ffb..d1a401d 100644
--- a/Makefile
+++ b/Makefile
@@ -55,8 +55,8 @@ staph:
@echo "cluster staph"; \
pangraph cluster -d data/staph data/staph/assemblies/*.fna.gz
@echo "build staph"; \
- pangraph build -d data/staph -m 500 -b 0 -e 2500 -w 1000 data/staph/guide.json 2>staph-e2500-w1000.err 1>staph-e2500-w1000.log
-
+ pangraph build -d data/staph -m 500 -b 0 -e 2500 -w 1000 data/staph/guide.json
+# 2>staph-e2500-w1000.err 1>staph-e2500-w1000.log
# figures
# figs/figure1.png: $(STAT)
diff --git a/pangraph/graph.py b/pangraph/graph.py
index d154f74..cc0c8e7 100644
--- a/pangraph/graph.py
+++ b/pangraph/graph.py
@@ -468,60 +468,77 @@ class Graph(object):
new_blocks.extend(update(old_ref, new_refs, hit['ref'], Strand.Plus))
new_blocks.extend(update(old_qry, new_qrys, hit['qry'], hit['orientation']))
- blk_list = set()
+ lblks_set_x, rblks_set_x = set(), set()
+ lblks_set_s, rblks_set_s = set(), set()
first = True
num_seqs = 0
- for tag in ref.muts.keys():
+ for tag in shared_blks[0].muts.keys():
# NOTE: this is a hack to deal with flipped orientations
- pos = sorted([self.seqs[tag[0]].position_of(b, tag[1]) for b in new_refs], key=lambda x: x[0])
+ pos = sorted([self.seqs[tag[0]].position_of(b, tag[1]) for b in shared_blks], key=lambda x: x[0])
beg, end = pos[0], pos[-1]
- blks = self.seqs[tag[0]][beg[0]-extend:end[1]+extend]
+ lblks_x = self.seqs[tag[0]][beg[0]-extend:beg[0]+window]
+ rblks_x = self.seqs[tag[0]][end[1]-window:end[1]+extend]
+
+ lblks_s = self.seqs[tag[0]][beg[0]:beg[0]+window]
+ rblks_s = self.seqs[tag[0]][end[1]-window:end[1]]
if first:
- blk_list = set([b.id for b in blks])
+ lblks_set_x = set([b.id for b in lblks_x])
+ rblks_set_x = set([b.id for b in rblks_x])
+
+ lblks_set_s = set([b.id for b in lblks_s])
+ rblks_set_s = set([b.id for b in rblks_s])
+
first = False
else:
- blk_list.intersection_update(set([b.id for b in blks]))
- num_seqs += 1
+ lblks_set_x.intersection_update(set([b.id for b in lblks_x]))
+ rblks_set_x.intersection_update(set([b.id for b in rblks_x]))
- for tag in qry.muts.keys():
- try:
- pos = sorted([self.seqs[tag[0]].position_of(b, tag[1]) for b in new_qrys], key=lambda x: x[0])
- except:
- breakpoint("bad find")
- beg, end = pos[0], pos[-1]
- blks = self.seqs[tag[0]][beg[0]-extend:end[1]+extend]
- blk_list.intersection_update(set([b.id for b in blks]))
+ lblks_set_s.intersection_update(set([b.id for b in lblks_s]))
+ rblks_set_s.intersection_update(set([b.id for b in rblks_s]))
num_seqs += 1
- delta = len(blk_list)-len(shared_blks)
- if delta > 0 and num_seqs > 1:
- print(f">LEN={delta}", end=';')
- fd = [None, None]
- path = [None, None]
- try:
- fd[0], path[0] = tempfile.mkstemp()
- fd[1], path[1] = tempfile.mkstemp()
- with os.fdopen(fd[0], 'w') as tmp1, os.fdopen(fd[1], 'w') as tmp2:
- tmp = [tmp1, tmp2]
- for i, tag in enumerate(chain(ref.muts.keys(), qry.muts.keys())):
- pos = sorted([self.seqs[tag[0]].position_of(b, tag[1]) for b in shared_blks], key=lambda x: x[0])
- beg, end = pos[0], pos[-1]
-
- for n, (left, right) in enumerate([(beg[0]-extend,beg[0]+window), (end[1]-window,end[1]+extend)]):
- tmp[n].write(f">isolate_{i:04d}\n")
+ def emit(side):
+ if side == 'left':
+ delta = len(lblks_set_x)-len(lblks_set_s)
+ elif side == 'right':
+ delta = len(rblks_set_x)-len(rblks_set_s)
+ else:
+ raise ValueError(f"unrecognized argument '{side}' for side")
+
+ if delta > 0 and num_seqs > 1:
+ print(f">LEN={delta}", end=';')
+ try:
+ fd, path = tempfile.mkstemp()
+ with os.fdopen(fd, 'w') as tmp:
+ for i, tag in enumerate(merged_blks[0].muts.keys()):
+ pos = sorted([self.seqs[tag[0]].position_of(b, tag[1]) for b in shared_blks], key=lambda x: x[0])
+ beg, end = pos[0], pos[-1]
+
+ if side == 'left':
+ left, right = beg[0]-extend,beg[0]+window
+ elif side == 'right':
+ left, right = end[1]-window,end[1]+extend
+ else:
+ raise ValueError(f"unrecognized argument '{side}' for side")
+
+ iso_blks = self.seqs[tag[0]][left:right]
+ print("POSITIONS", pos)
+ print("LIST", lblks_set_x if side == 'left' else rblks_set_x)
+ print("SHARED", lblks_set_s if side == 'left' else rblks_set_s)
+ print("ISO", iso_blks)
+ breakpoint("stop")
+ tmp.write(f">isolate_{i:04d} {','.join(b.id for b in iso_blks)}\n")
s = self.seqs[tag[0]].sequence_range(left,right)
if len(s) > extend + window:
breakpoint(f"bad sequence slicing: {len(s)}")
- tmp[n].write(s + '\n')
+ tmp.write(s + '\n')
- tmp[0].flush()
- tmp[1].flush()
+ tmp.flush()
- def make_tree(n):
proc = [None, None]
out = [None, None]
err = [None, None]
- proc[0] = subprocess.Popen(f"mafft --auto {path[n]}",
+ proc[0] = subprocess.Popen(f"mafft --auto {path}",
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
shell=True)
@@ -533,45 +550,16 @@ class Graph(object):
out[0], err[0] = proc[0].communicate()
out[1], err[1] = proc[1].communicate(input=out[0])
tree = Phylo.read(io.StringIO(out[1].decode('utf-8')), format='newick')
- print(f"ALIGNMENT=\n{out[0]}")
+ print(f"ALIGNMENT=\n{out[0]}", end=";")
print(f"SCORE={tree.total_branch_length()/(2*num_seqs)}", end=";")
+ print("\n", end="")
+ finally:
+ os.remove(path)
+ else:
+ print(f">NO MATCH")
- make_tree(0)
- make_tree(1)
- print("\n", end="")
- finally:
- os.remove(path[0])
- os.remove(path[1])
- else:
- print(f">NO MATCH")
-
- # NOTE: debugging code
- # if len(blk_list) < len(shared_blks):
- # ref_list = set()
- # first = True
- # for tag in ref.muts.keys():
- # beg = self.seqs[tag[0]].position_of(new_refs[0], tag[1])
- # end = self.seqs[tag[0]].position_of(new_refs[-1], tag[1])
- # blks = self.seqs[tag[0]][beg[0]-EXTEND:end[1]+EXTEND]
- # if first:
- # ref_list = set([b.id for b in blks])
- # first = False
- # else:
- # ref_list.intersection_update(set([b.id for b in blks]))
-
- # qry_list = set()
- # first = True
- # for tag in qry.muts.keys():
- # beg = self.seqs[tag[0]].position_of(new_qrys[0], tag[1])
- # end = self.seqs[tag[0]].position_of(new_qrys[-1], tag[1])
- # blks = self.seqs[tag[0]][beg[0]-EXTEND:end[1]+EXTEND]
- # if first:
- # qry_list = set([b.id for b in blks])
- # first = False
- # else:
- # qry_list.intersection_update(set([b.id for b in blks]))
-
- # breakpoint("inconsistent number of blocks")
+ emit('left')
+ emit('right')
self.prune_blks()
diff --git a/scripts/parse_log.py b/scripts/parse_log.py
index 3a0aed0..8532d3d 100755
--- a/scripts/parse_log.py
+++ b/scripts/parse_log.py
@@ -2,7 +2,11 @@
"""
script to process our end repair log files for plotting
"""
+import ast
import argparse
+
+from io import StringIO
+from Bio import AlignIO
from collections import defaultdict
level_preset = "+++LEVEL="
@@ -11,6 +15,40 @@ level_offset = len(level_preset)
score_preset = "SCORE="
score_offset = len(score_preset)
+msaln_preset = "ALIGNMENT="
+msaln_offset = len(msaln_preset)
+
+def unpack(line):
+ offset = [line.find(";")]
+ offset.append(line.find(";", offset[0]+1))
+
+ align = StringIO(ast.literal_eval(line[msaln_offset:offset[0]]).decode('utf-8'))
+ align = next(AlignIO.parse(align, 'fasta'))
+ score = float(line[offset[0]+1+score_offset:offset[1]])
+
+ return align, score
+
+# TODO: remove hardcoded numbers
+def save_aln_examples(results):
+ stats = results[(2500,1000)]
+ nums = [0, 0, 0]
+ for score, aln in stats[1]['hits']:
+ if len(aln) < 10:
+ continue
+
+ if score < 1e-4:
+ with open(f"scratch/1/eg_{nums[0]}.fna", 'w') as fd:
+ AlignIO.write(aln, fd, "fasta")
+ nums[0] += 1
+ elif score < 1e-2:
+ with open(f"scratch/2/eg_{nums[1]}.fna", 'w') as fd:
+ AlignIO.write(aln, fd, "fasta")
+ nums[1] += 1
+ else:
+ with open(f"scratch/3/eg_{nums[2]}.fna", 'w') as fd:
+ AlignIO.write(aln, fd, "fasta")
+ nums[2] += 1
+
def main(args):
results = {}
for log_path in args:
@@ -28,14 +66,9 @@ def main(args):
stats[level]['miss'] += 1
continue
if line[1:].startswith("LEN="):
- offset = [line.find(";")]
- offset.append(line.find(";", offset[0]+1))
- offset.append(line.find(";", offset[1]+1))
-
- score = [None, None]
- score[0] = float(line[offset[0]+1+score_offset:offset[1]])
- score[1] = float(line[offset[1]+1+score_offset:offset[2]])
- stats[level]['hits'].extend(score)
+ laln, lscore = unpack(log.readline())
+ raln, rscore = unpack(log.readline())
+ stats[level]['hits'].extend([(lscore, laln), (rscore, raln)])
continue
raise ValueError(f"invalid syntax: {line[1:]}")
if len(stats) > 0:
@@ -50,3 +83,4 @@ parser.add_argument('files', type=str, nargs='+')
if __name__ == "__main__":
args = parser.parse_args()
results = main(args.files)
+ save_aln_examples(results)