aboutsummaryrefslogtreecommitdiff
path: root/pangraph/graph.py
diff options
context:
space:
mode:
Diffstat (limited to 'pangraph/graph.py')
-rw-r--r--pangraph/graph.py175
1 files changed, 155 insertions, 20 deletions
diff --git a/pangraph/graph.py b/pangraph/graph.py
index bc1e9d9..9f00587 100644
--- a/pangraph/graph.py
+++ b/pangraph/graph.py
@@ -1,15 +1,21 @@
-import os, sys
+import io, os, sys
import json
import numpy as np
import pprint
+import subprocess
+import tempfile
+from io import StringIO
from glob import glob
from collections import defaultdict, Counter
+from itertools import chain
-from Bio import SeqIO, Phylo
+from Bio import AlignIO, SeqIO, Phylo
from Bio.Seq import Seq
from Bio.SeqRecord import SeqRecord
+from scipy.stats import entropy
+
from . import suffix
from .block import Block
from .sequence import Node, Path
@@ -18,10 +24,23 @@ from .utils import Strand, as_string, parse_paf, panic, as_record, new_strand
# ------------------------------------------------------------------------
# globals
-EXTEND = 2500
+# WINDOW = 1000
+# EXTEND = 2500
pp = pprint.PrettyPrinter(indent=4)
# ------------------------------------------------------------------------
+# utility
+
+def alignment_entropy(rdr):
+ try:
+ aln = np.array([list(rec) for rec in AlignIO.read(rdr, 'fasta')], np.character).view(np.uint8)
+ S = sum(entropy(np.bincount(aln[:,i])/aln.shape[0]) for i in range(aln.shape[1]))
+ return S/aln.shape[1]
+ except Exception as msg:
+ print(f"ERROR: {msg}")
+ return None
+
+# ------------------------------------------------------------------------
# Junction class
# simple struct
@@ -136,7 +155,7 @@ class Graph(object):
graphs, names = [], []
for name, path in G.seqs.items():
blks = set([b.id for b in path.blocks()])
- gi = [ i for i, g in enumerate(graphs) if overlaps(blks, g)]
+ gi = [i for i, g in enumerate(graphs) if overlaps(blks, g)]
if len(gi) == 0:
graphs.append(blks)
names.append(set([name]))
@@ -159,7 +178,7 @@ class Graph(object):
# ---------------
# methods
- def union(self, qpath, rpath, out, cutoff=0, alpha=10, beta=2, extensive=False):
+ def union(self, qpath, rpath, out, cutoff=0, alpha=10, beta=2, extensive=False, edge_window=1000, edge_extend=2500):
from seqanpy import align_global as align
# ----------------------------------
@@ -319,19 +338,11 @@ class Graph(object):
or not accepted(hit):
continue
- merged = True
- new_blks = self.merge(proc(hit))
+ merged = True
+ self.merge(proc(hit), edge_window, edge_extend)
merged_blks.add(hit['ref']['name'])
merged_blks.add(hit['qry']['name'])
- # for blk in new_blks:
- # for iso in blk.isolates:
- # path = self.seqs[iso]
- # x, n = path.position_of(blk)
- # lb, ub = max(0, x-EXTEND), min(x+blk.len_of(iso, n)+EXTEND, len(path))
- # subpath = path[lb:ub]
- # print(subpath, file=sys.stderr)
- # breakpoint("stop")
self.remove_transitives()
for path in self.seqs.values():
@@ -407,9 +418,6 @@ class Graph(object):
# TODO: check that isos is constant along the chain
for iso in self.blks[c[0][0]].isolates.keys():
self.seqs[iso].merge(c[0], c[-1], new_blk)
- # for n in self.seqs[iso].nodes:
- # if n.blk.id in [e[0] for e in c]:
- # breakpoint("bad deletion")
self.blks[new_blk.id] = new_blk
for b, _ in c:
@@ -421,7 +429,7 @@ class Graph(object):
blks.update(path.blocks())
self.blks = {b.id:self.blks[b.id] for b in blks}
- def merge(self, hit):
+ def merge(self, hit, window, extend):
old_ref = self.blks[hit['ref']['name']]
old_qry = self.blks[hit['qry']['name']]
@@ -444,7 +452,7 @@ class Graph(object):
"qry_name" : hit["qry"]["name"],
"orientation" : hit["orientation"]}
- merged_blks, new_qrys, new_refs, blk_map = Block.from_aln(aln)
+ merged_blks, new_qrys, new_refs, shared_blks, blk_map = Block.from_aln(aln)
for merged_blk in merged_blks:
self.blks[merged_blk.id] = merged_blk
@@ -474,6 +482,133 @@ class Graph(object):
new_blocks = []
new_blocks.extend(update(old_ref, new_refs, hit['ref'], Strand.Plus))
new_blocks.extend(update(old_qry, new_qrys, hit['qry'], hit['orientation']))
+
+ lblks_set_x, rblks_set_x = set(), set()
+ lblks_set_s, rblks_set_s = set(), set()
+ first = True
+ num_seqs = 0
+ for tag in shared_blks[0].muts.keys():
+ pos = [self.seqs[tag[0]].position_of(b, tag[1]) for b in shared_blks]
+ strand = [self.seqs[tag[0]].orientation_of(b, tag[1]) for b in shared_blks]
+ beg, end = pos[0], pos[-1]
+ if strand[0] == Strand.Plus:
+ lwindow = min(window, shared_blks[0].len_of(*tag))
+ rwindow = min(window, shared_blks[-1].len_of(*tag))
+
+ lblks_x = self.seqs[tag[0]][beg[0]-extend:beg[0]+lwindow]
+ rblks_x = self.seqs[tag[0]][end[1]-rwindow:end[1]+extend]
+
+ # lblks_s = self.seqs[tag[0]][beg[0]:beg[0]+window]
+ # rblks_s = self.seqs[tag[0]][end[1]-window:end[1]]
+ elif strand[0] == Strand.Minus:
+ lwindow = min(window, shared_blks[-1].len_of(*tag))
+ rwindow = min(window, shared_blks[0].len_of(*tag))
+ rblks_x = self.seqs[tag[0]][beg[0]-extend:beg[0]+rwindow]
+ lblks_x = self.seqs[tag[0]][end[1]-lwindow:end[1]+extend]
+
+ # rblks_s = self.seqs[tag[0]][beg[0]:beg[0]+window]
+ # lblks_s = self.seqs[tag[0]][end[1]-window:end[1]]
+ else:
+ raise ValueError("unrecognized strand polarity")
+
+ if first:
+ lblks_set_x = set([b.id for b in lblks_x])
+ rblks_set_x = set([b.id for b in rblks_x])
+
+ # lblks_set_s = set([b.id for b in lblks_s])
+ # rblks_set_s = set([b.id for b in rblks_s])
+
+ lblks_set_s = set([b.id for b in lblks_x])
+ rblks_set_s = set([b.id for b in rblks_x])
+
+ first = False
+ else:
+ lblks_set_x.intersection_update(set([b.id for b in lblks_x]))
+ rblks_set_x.intersection_update(set([b.id for b in rblks_x]))
+
+ lblks_set_s.update(set([b.id for b in lblks_x]))
+ rblks_set_s.update(set([b.id for b in rblks_x]))
+ # lblks_set_s.intersection_update(set([b.id for b in lblks_s]))
+ # rblks_set_s.intersection_update(set([b.id for b in rblks_s]))
+ num_seqs += 1
+
+ def emit(side):
+ if side == 'left':
+ delta = len(lblks_set_s)-len(lblks_set_x)
+ elif side == 'right':
+ delta = len(lblks_set_s)-len(rblks_set_x)
+ else:
+ raise ValueError(f"unrecognized argument '{side}' for side")
+
+ if delta > 0 and num_seqs > 1:
+ print(f">LEN={delta}", end=';')
+ try:
+ fd, path = tempfile.mkstemp()
+ with os.fdopen(fd, 'w') as tmp:
+ for i, tag in enumerate(merged_blks[0].muts.keys()):
+ pos = [self.seqs[tag[0]].position_of(b, tag[1]) for b in shared_blks]
+ strand = [self.seqs[tag[0]].orientation_of(b, tag[1]) for b in shared_blks]
+ beg, end = pos[0], pos[-1]
+
+ if strand[0] == Strand.Plus:
+ if side == 'left':
+ left, right = beg[0]-extend,beg[0]+min(window,shared_blks[0].len_of(*tag))
+ elif side == 'right':
+ left, right = end[1]-min(window,shared_blks[-1].len_of(*tag)),end[1]+extend
+ else:
+ raise ValueError(f"unrecognized argument '{side}' for side")
+
+ elif strand[0] == Strand.Minus:
+ if side == 'left':
+ left, right = end[1]-min(window,shared_blks[-1].len_of(*tag)),end[1]+extend
+ elif side == 'right':
+ left, right = beg[0]-extend,beg[0]+min(window, shared_blks[0].len_of(*tag))
+ else:
+ raise ValueError(f"unrecognized argument '{side}' for side")
+
+ iso_blks = self.seqs[tag[0]][left:right]
+ # print("POSITIONS", pos)
+ # print("STRAND", strand)
+ # print("LIST", shared_blks)
+ # print("MERGED", merged_blks)
+ # print("INTERSECTION", lblks_set_x if side == 'left' else rblks_set_x)
+ # print("UNION", lblks_set_s if side == 'left' else rblks_set_s)
+ # print("ISO", iso_blks)
+ # breakpoint("stop")
+ tmp.write(f">isolate_{i:04d} {','.join(b.id for b in iso_blks)}\n")
+ s = self.seqs[tag[0]].sequence_range(left,right)
+ if len(s) > extend + window:
+ breakpoint(f"bad sequence slicing: {len(s)}")
+ tmp.write(s + '\n')
+
+ tmp.flush()
+
+ proc = subprocess.Popen(f"mafft --auto {path}",
+ stdout=subprocess.PIPE,
+ stderr=subprocess.PIPE,
+ shell=True)
+ # proc[1] = subprocess.Popen(f"fasttree",
+ # stdin =subprocess.PIPE,
+ # stdout=subprocess.PIPE,
+ # stderr=subprocess.PIPE,
+ # shell=True)
+ out, err = proc.communicate()
+ # out[1], err[1] = proc[1].communicate(input=out[0])
+ # tree = Phylo.read(io.StringIO(out[1].decode('utf-8')), format='newick')
+ print(f"ALIGNMENT={out}", end=";")
+ rdr = StringIO(out.decode('utf-8'))
+ print(f"SCORE={alignment_entropy(rdr)}", end=";")
+ rdr.close()
+ # print(f"SCORE={tree.total_branch_length()/(2*num_seqs)}", end=";")
+ print("\n", end="")
+ finally:
+ os.remove(path)
+ else:
+ print(f">NO MATCH")
+
+ emit('left')
+ emit('right')
+
self.prune_blks()
return [b[0] for b in new_blocks]