diff options
Diffstat (limited to 'pangraph/graph.py')
-rw-r--r-- | pangraph/graph.py | 175 |
1 file changed, 155 insertions, 20 deletions
diff --git a/pangraph/graph.py b/pangraph/graph.py
index bc1e9d9..9f00587 100644
--- a/pangraph/graph.py
+++ b/pangraph/graph.py
@@ -1,15 +1,21 @@
-import os, sys
+import io, os, sys
 import json
 import numpy as np
 import pprint
+import subprocess
+import tempfile
 
+from io import StringIO
 from glob import glob
 from collections import defaultdict, Counter
+from itertools import chain
 
-from Bio import SeqIO, Phylo
+from Bio import AlignIO, SeqIO, Phylo
 from Bio.Seq import Seq
 from Bio.SeqRecord import SeqRecord
 
+from scipy.stats import entropy
+
 from . import suffix
 from .block import Block
 from .sequence import Node, Path
@@ -18,10 +24,23 @@ from .utils import Strand, as_string, parse_paf, panic, as_record, new_strand
 # ------------------------------------------------------------------------
 # globals
 
-EXTEND = 2500
+# WINDOW = 1000
+# EXTEND = 2500
 pp = pprint.PrettyPrinter(indent=4)
 
 # ------------------------------------------------------------------------
+# utility
+
+def alignment_entropy(rdr):
+    try:
+        aln = np.array([list(rec) for rec in AlignIO.read(rdr, 'fasta')], np.character).view(np.uint8)
+        S = sum(entropy(np.bincount(aln[:,i])/aln.shape[0]) for i in range(aln.shape[1]))
+        return S/aln.shape[1]
+    except Exception as msg:
+        print(f"ERROR: {msg}")
+        return None
+
+# ------------------------------------------------------------------------
 # Junction class
 # simple struct
 
@@ -136,7 +155,7 @@ class Graph(object):
         graphs, names = [], []
         for name, path in G.seqs.items():
             blks = set([b.id for b in path.blocks()])
-            gi = [ i for i, g in enumerate(graphs) if overlaps(blks, g)]
+            gi = [i for i, g in enumerate(graphs) if overlaps(blks, g)]
             if len(gi) == 0:
                 graphs.append(blks)
                 names.append(set([name]))
@@ -159,7 +178,7 @@ class Graph(object):
     # ---------------
     # methods
 
-    def union(self, qpath, rpath, out, cutoff=0, alpha=10, beta=2, extensive=False):
+    def union(self, qpath, rpath, out, cutoff=0, alpha=10, beta=2, extensive=False, edge_window=1000, edge_extend=2500):
         from seqanpy import align_global as align
 
         # ----------------------------------
@@ -319,19 +338,11 @@ class Graph(object):
             or not accepted(hit):
                 continue
 
-            merged = True
-            new_blks = self.merge(proc(hit))
+            merged = True
+            self.merge(proc(hit), edge_window, edge_extend)
             merged_blks.add(hit['ref']['name'])
             merged_blks.add(hit['qry']['name'])
 
-            # for blk in new_blks:
-            # for iso in blk.isolates:
-            # path = self.seqs[iso]
-            # x, n = path.position_of(blk)
-            # lb, ub = max(0, x-EXTEND), min(x+blk.len_of(iso, n)+EXTEND, len(path))
-            # subpath = path[lb:ub]
-            # print(subpath, file=sys.stderr)
-            # breakpoint("stop")
         self.remove_transitives()
 
         for path in self.seqs.values():
@@ -407,9 +418,6 @@ class Graph(object):
             # TODO: check that isos is constant along the chain
             for iso in self.blks[c[0][0]].isolates.keys():
                 self.seqs[iso].merge(c[0], c[-1], new_blk)
-                # for n in self.seqs[iso].nodes:
-                # if n.blk.id in [e[0] for e in c]:
-                # breakpoint("bad deletion")
 
             self.blks[new_blk.id] = new_blk
             for b, _ in c:
@@ -421,7 +429,7 @@ class Graph(object):
             blks.update(path.blocks())
         self.blks = {b.id:self.blks[b.id] for b in blks}
 
-    def merge(self, hit):
+    def merge(self, hit, window, extend):
         old_ref = self.blks[hit['ref']['name']]
         old_qry = self.blks[hit['qry']['name']]
 
@@ -444,7 +452,7 @@ class Graph(object):
                "qry_name" : hit["qry"]["name"],
                "orientation" : hit["orientation"]}
 
-        merged_blks, new_qrys, new_refs, blk_map = Block.from_aln(aln)
+        merged_blks, new_qrys, new_refs, shared_blks, blk_map = Block.from_aln(aln)
         for merged_blk in merged_blks:
             self.blks[merged_blk.id] = merged_blk
 
@@ -474,6 +482,133 @@ class Graph(object):
         new_blocks = []
         new_blocks.extend(update(old_ref, new_refs, hit['ref'], Strand.Plus))
         new_blocks.extend(update(old_qry, new_qrys, hit['qry'], hit['orientation']))
+
+        lblks_set_x, rblks_set_x = set(), set()
+        lblks_set_s, rblks_set_s = set(), set()
+        first = True
+        num_seqs = 0
+        for tag in shared_blks[0].muts.keys():
+            pos = [self.seqs[tag[0]].position_of(b, tag[1]) for b in shared_blks]
+            strand = [self.seqs[tag[0]].orientation_of(b, tag[1]) for b in shared_blks]
+            beg, end = pos[0], pos[-1]
+            if strand[0] == Strand.Plus:
+                lwindow = min(window, shared_blks[0].len_of(*tag))
+                rwindow = min(window, shared_blks[-1].len_of(*tag))
+
+                lblks_x = self.seqs[tag[0]][beg[0]-extend:beg[0]+lwindow]
+                rblks_x = self.seqs[tag[0]][end[1]-rwindow:end[1]+extend]
+
+                # lblks_s = self.seqs[tag[0]][beg[0]:beg[0]+window]
+                # rblks_s = self.seqs[tag[0]][end[1]-window:end[1]]
+            elif strand[0] == Strand.Minus:
+                lwindow = min(window, shared_blks[-1].len_of(*tag))
+                rwindow = min(window, shared_blks[0].len_of(*tag))
+                rblks_x = self.seqs[tag[0]][beg[0]-extend:beg[0]+rwindow]
+                lblks_x = self.seqs[tag[0]][end[1]-lwindow:end[1]+extend]
+
+                # rblks_s = self.seqs[tag[0]][beg[0]:beg[0]+window]
+                # lblks_s = self.seqs[tag[0]][end[1]-window:end[1]]
+            else:
+                raise ValueError("unrecognized strand polarity")
+
+            if first:
+                lblks_set_x = set([b.id for b in lblks_x])
+                rblks_set_x = set([b.id for b in rblks_x])
+
+                # lblks_set_s = set([b.id for b in lblks_s])
+                # rblks_set_s = set([b.id for b in rblks_s])
+
+                lblks_set_s = set([b.id for b in lblks_x])
+                rblks_set_s = set([b.id for b in rblks_x])
+
+                first = False
+            else:
+                lblks_set_x.intersection_update(set([b.id for b in lblks_x]))
+                rblks_set_x.intersection_update(set([b.id for b in rblks_x]))
+
+                lblks_set_s.update(set([b.id for b in lblks_x]))
+                rblks_set_s.update(set([b.id for b in rblks_x]))
+                # lblks_set_s.intersection_update(set([b.id for b in lblks_s]))
+                # rblks_set_s.intersection_update(set([b.id for b in rblks_s]))
+            num_seqs += 1
+
+        def emit(side):
+            if side == 'left':
+                delta = len(lblks_set_s)-len(lblks_set_x)
+            elif side == 'right':
+                delta = len(lblks_set_s)-len(rblks_set_x)
+            else:
+                raise ValueError(f"unrecognized argument '{side}' for side")
+
+            if delta > 0 and num_seqs > 1:
+                print(f">LEN={delta}", end=';')
+                try:
+                    fd, path = tempfile.mkstemp()
+                    with os.fdopen(fd, 'w') as tmp:
+                        for i, tag in enumerate(merged_blks[0].muts.keys()):
+                            pos = [self.seqs[tag[0]].position_of(b, tag[1]) for b in shared_blks]
+                            strand = [self.seqs[tag[0]].orientation_of(b, tag[1]) for b in shared_blks]
+                            beg, end = pos[0], pos[-1]
+
+                            if strand[0] == Strand.Plus:
+                                if side == 'left':
+                                    left, right = beg[0]-extend,beg[0]+min(window,shared_blks[0].len_of(*tag))
+                                elif side == 'right':
+                                    left, right = end[1]-min(window,shared_blks[-1].len_of(*tag)),end[1]+extend
+                                else:
+                                    raise ValueError(f"unrecognized argument '{side}' for side")
+
+                            elif strand[0] == Strand.Minus:
+                                if side == 'left':
+                                    left, right = end[1]-min(window,shared_blks[-1].len_of(*tag)),end[1]+extend
+                                elif side == 'right':
+                                    left, right = beg[0]-extend,beg[0]+min(window, shared_blks[0].len_of(*tag))
+                                else:
+                                    raise ValueError(f"unrecognized argument '{side}' for side")
+
+                            iso_blks = self.seqs[tag[0]][left:right]
+                            # print("POSITIONS", pos)
+                            # print("STRAND", strand)
+                            # print("LIST", shared_blks)
+                            # print("MERGED", merged_blks)
+                            # print("INTERSECTION", lblks_set_x if side == 'left' else rblks_set_x)
+                            # print("UNION", lblks_set_s if side == 'left' else rblks_set_s)
+                            # print("ISO", iso_blks)
+                            # breakpoint("stop")
+                            tmp.write(f">isolate_{i:04d} {','.join(b.id for b in iso_blks)}\n")
+                            s = self.seqs[tag[0]].sequence_range(left,right)
+                            if len(s) > extend + window:
+                                breakpoint(f"bad sequence slicing: {len(s)}")
+                            tmp.write(s + '\n')
+
+                        tmp.flush()
+
+                        proc = subprocess.Popen(f"mafft --auto {path}",
+                                                stdout=subprocess.PIPE,
+                                                stderr=subprocess.PIPE,
+                                                shell=True)
+                        # proc[1] = subprocess.Popen(f"fasttree",
+                        # stdin =subprocess.PIPE,
+                        # stdout=subprocess.PIPE,
+                        # stderr=subprocess.PIPE,
+                        # shell=True)
+                        out, err = proc.communicate()
+                        # out[1], err[1] = proc[1].communicate(input=out[0])
+                        # tree = Phylo.read(io.StringIO(out[1].decode('utf-8')), format='newick')
+                        print(f"ALIGNMENT={out}", end=";")
+                        rdr = StringIO(out.decode('utf-8'))
+                        print(f"SCORE={alignment_entropy(rdr)}", end=";")
+                        rdr.close()
+                        # print(f"SCORE={tree.total_branch_length()/(2*num_seqs)}", end=";")
+                        print("\n", end="")
+                finally:
+                    os.remove(path)
+            else:
+                print(f">NO MATCH")
+
+        emit('left')
+        emit('right')
+
         self.prune_blks()
 
         return [b[0] for b in new_blocks]