aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-08-12 12:01:33 -0700
committerGitHub <noreply@github.com>2020-08-12 12:01:33 -0700
commita5f38a5350fbc876ddfb9a8fb041d3937083c4c9 (patch)
tree076c7f17cecfd83acdd2641784b052e5a09b4470
parent7ae833aff30dfcd0ef90dcbd528dd15a086b5838 (diff)
parent8eb9feb037ef53e5c498335f22437195da5d437c (diff)
Merge pull request #3 from nnoll/feat/merge-transitives
Track repeated junctions to merge transitives
-rw-r--r--Makefile2
-rw-r--r--pangraph/block.py12
-rw-r--r--pangraph/build.py9
-rw-r--r--pangraph/graph.py153
-rw-r--r--pangraph/sequence.py44
-rw-r--r--pangraph/tree.py7
6 files changed, 204 insertions, 23 deletions
diff --git a/Makefile b/Makefile
index b73c6ff..9681f3b 100644
--- a/Makefile
+++ b/Makefile
@@ -55,7 +55,7 @@ staph:
@echo "cluster staph"; \
pangraph cluster -d data/staph data/staph/assemblies/*.fna.gz
@echo "build staph"; \
- pangraph build -d data/staph -m 500 -b 0 data/staph/guide.json 1>data/staph/pangraph.json
+ pangraph build -d data/staph -m 500 -b 0 data/staph/guide.json # 1>data/staph/pangraph.json
# figures
diff --git a/pangraph/block.py b/pangraph/block.py
index 4c438c4..05b7806 100644
--- a/pangraph/block.py
+++ b/pangraph/block.py
@@ -1,7 +1,7 @@
import numpy as np
import numpy.random as rng
-from collections import defaultdict
+from collections import defaultdict, Counter
from .utils import parse_cigar, wcpair, as_array, as_string
# ------------------------------------------------------------------------
@@ -28,6 +28,12 @@ class Block(object):
self.seq = None
self.muts = {}
+    def __str__(self):
+        """Render the block as the string form of its id."""
+        ident = self.id
+        return str(ident)
+
+    def __repr__(self):
+        """Use the same text as `__str__` for the debugging representation."""
+        return self.__str__()
+
# ------------------
# properties
@@ -41,7 +47,7 @@ class Block(object):
@property
def isolates(self):
- return list(set([ k[0] for k in self.muts.keys() ]))
+ return dict(Counter([k[0] for k in self.muts]))
# ------------------
# static methods
@@ -69,7 +75,7 @@ class Block(object):
@classmethod
def cat(cls, blks):
- nblk = cls()
+ nblk = Block()
assert all([blks[0].muts.keys() == b2.muts.keys() for b2 in blks[1:]])
nblk.seq = np.concatenate([b.seq for b in blks])
diff --git a/pangraph/build.py b/pangraph/build.py
index 14f96d3..01595b7 100644
--- a/pangraph/build.py
+++ b/pangraph/build.py
@@ -3,7 +3,6 @@ build a pangenome alignment from an annotated guide tree
"""
import os, sys
import builtins
-import cProfile
from .utils import mkdir, log
from .tree import Tree
@@ -71,12 +70,10 @@ def main(args):
mkdir(tmp)
log("aligning")
- with cProfile.Profile() as pr:
- T.align(tmp, args.len, args.mu, args.beta, args.extensive, args.statistics)
- # TODO: when debugging phase is done, remove tmp directory
+ T.align(tmp, args.len, args.mu, args.beta, args.extensive, args.statistics)
+ # TODO: when debugging phase is done, remove tmp directory
- graphs = T.collect()
- pr.dump_stats("perf.prof")
+ graphs = T.collect()
for i, g in enumerate(graphs):
log(f"graph {i}: nseqs: {len(g.seqs)} nblks: {len(g.blks)}")
diff --git a/pangraph/graph.py b/pangraph/graph.py
index 6589443..bc1e9d9 100644
--- a/pangraph/graph.py
+++ b/pangraph/graph.py
@@ -1,8 +1,10 @@
import os, sys
import json
import numpy as np
+import pprint
-from glob import glob
+from glob import glob
+from collections import defaultdict, Counter
from Bio import SeqIO, Phylo
from Bio.Seq import Seq
@@ -17,6 +19,62 @@ from .utils import Strand, as_string, parse_paf, panic, as_record, new_strand
# globals
EXTEND = 2500
+pp = pprint.PrettyPrinter(indent=4)
+
+# ------------------------------------------------------------------------
+# Junction class
+# simple struct
+
+class Junction(object):
+    """Ordered pair of adjacent path nodes: `left` is followed by `right`.
+
+    A junction and its reverse complement (see `reverse`) describe the same
+    adjacency read from opposite strands. The two therefore compare equal and
+    hash identically, so either orientation addresses the same dict entry.
+    """
+    def __init__(self, left, right):
+        self.left = left
+        self.right = right
+
+    def __eq__(self, other):
+        """Return True when `other` is this junction in either orientation."""
+        if not isinstance(other, Junction):
+            return NotImplemented
+        # Must agree with __hash__, which is orientation-independent: objects
+        # that share a hash because they are mutual reverses have to be equal,
+        # otherwise they end up as separate keys in the same hash bucket.
+        return self.data == other.data or self.data == other.reverse().data
+
+    def __hash__(self):
+        # frozenset makes the hash independent of which orientation is stored
+        return hash(frozenset([self.data, self.reverse().data]))
+
+    def __str__(self):
+        return f"({self.left}, {self.right})"
+
+    def __repr__(self):
+        return str(self)
+
+    @property
+    def data(self):
+        """((left block id, left strand), (right block id, right strand))."""
+        return (self.left_blk, self.right_blk)
+
+    @property
+    def right_id(self):
+        """Id of the block on the right side."""
+        return self.right.blk.id
+
+    @property
+    def left_id(self):
+        """Id of the block on the left side."""
+        return self.left.blk.id
+
+    @property
+    def left_blk(self):
+        """(block id, strand) of the left node."""
+        return (self.left.blk.id, self.left.strand)
+
+    @property
+    def right_blk(self):
+        """(block id, strand) of the right node."""
+        return (self.right.blk.id, self.right.strand)
+
+    def reverse(self):
+        """Return the same adjacency read from the other strand.
+
+        The two sides swap places and each strand is negated; block and num
+        are carried over unchanged into freshly built nodes.
+        """
+        return Junction(
+            Node(self.right.blk, self.right.num, Strand(-1*self.right.strand)),
+            Node(self.left.blk, self.left.num, Strand(-1*self.left.strand)),
+        )
+
+def rev_blk(b):
+    """Flip the strand of a (block id, strand) pair, keeping the id."""
+    blk_id = b[0]
+    flipped = Strand(-1*b[1])
+    return (blk_id, flipped)
# ------------------------------------------------------------------------
# Graph class
@@ -266,20 +324,97 @@ class Graph(object):
merged_blks.add(hit['ref']['name'])
merged_blks.add(hit['qry']['name'])
- for blk in new_blks:
- for iso in blk.isolates:
- path = self.seqs[iso]
- x, n = path.position_of(blk)
- lb, ub = max(0, x-EXTEND), min(x+blk.len_of(iso, n)+EXTEND, len(path))
- subpath = path[lb:ub]
- print(subpath, file=sys.stderr)
- breakpoint("stop")
+ # for blk in new_blks:
+ # for iso in blk.isolates:
+ # path = self.seqs[iso]
+ # x, n = path.position_of(blk)
+ # lb, ub = max(0, x-EXTEND), min(x+blk.len_of(iso, n)+EXTEND, len(path))
+ # subpath = path[lb:ub]
+ # print(subpath, file=sys.stderr)
+ # breakpoint("stop")
+ self.remove_transitives()
for path in self.seqs.values():
path.rm_nil_blks()
return self, merged
+    # a junction is an ordered pair of nodes that sit next to each other on a path.
+    def junctions(self):
+        """Count, for every junction in the graph, how often each isolate crosses it.
+
+        Returns a dict mapping Junction -> {isolate name: number of occurrences}.
+        """
+        tally = defaultdict(Counter)
+        for name, path in self.seqs.items():
+            nodes = path.nodes
+            # a path made of a single node is skipped entirely
+            if len(nodes) == 1:
+                continue
+
+            # at idx 0 the predecessor is nodes[-1], so the last node is paired with the first
+            for idx, node in enumerate(nodes):
+                tally[Junction(nodes[idx-1], node)][name] += 1
+        return { junction:dict(counts) for junction, counts in tally.items() }
+
+    def remove_transitives(self):
+        """Fuse chains of blocks that always occur together into single blocks.
+
+        A junction counts as transitive when its two blocks report equal
+        `isolates` and the junction's own per-isolate counts equal them too.
+        Transitive junctions are stitched into linear chains of
+        (block id, strand) pairs. Each chain is concatenated into one new
+        block with Block.cat, the paths of the isolates of the chain's first
+        block are rewritten with Path.merge, and the chain's original blocks
+        are removed from self.blks.
+        """
+        js = self.junctions()
+        transitives = []
+        for j, isos in js.items():
+            left_eq_right = self.blks[j.left.blk.id].isolates == self.blks[j.right.blk.id].isolates
+            left_eq_isos = isos == self.blks[j.left.blk.id].isolates
+            if left_eq_right and left_eq_isos:
+                transitives.append(j)
+
+        # chains maps block id -> the chain (a list) that currently holds it;
+        # every member of a chain points at the same list object
+        chains = {}
+        for j in transitives:
+            if j.left_id in chains and j.right_id in chains:
+                # both sides already belong to chains: join the two chains
+                c1, c2 = chains[j.left_id], chains[j.right_id]
+                if c1 == c2:
+                    continue
+
+                # choose the orientation in which c1 and c2 meet at this junction;
+                # a chain is flipped by reversing its order and negating each strand
+                if j.left_blk==c1[-1] and j.right_blk==c2[0]:
+                    new_chain = c1 + c2
+                elif j.left_blk==c1[-1] and rev_blk(j.right_blk)==c2[-1]:
+                    new_chain = c1 + [rev_blk(b) for b in c2[::-1]]
+                elif rev_blk(j.left_blk)==c1[0] and j.right_blk==c2[0]:
+                    new_chain = [rev_blk(b) for b in c1[::-1]] + c2
+                elif rev_blk(j.left_blk)==c1[0] and rev_blk(j.right_blk)==c2[-1]:
+                    new_chain = c2 + c1
+                else:
+                    # NOTE(review): breakpoint is called with a message string; presumably a
+                    # project-level override of the builtin — confirm. If it returns,
+                    # new_chain is unbound (or stale) in the loop below.
+                    breakpoint("case not covered")
+
+                # repoint every member at the joined chain
+                for b, _ in new_chain:
+                    chains[b] = new_chain
+
+            elif j.left_id in chains:
+                # only the left block is chained: grow that chain in place by the right block
+                c = chains[j.left_id]
+                if j.left_blk == c[-1]:
+                    c.append(j.right_blk)
+                elif rev_blk(j.left_blk) == c[0]:
+                    c.insert(0, rev_blk(j.right_blk))
+                else:
+                    breakpoint("chains should be linear")
+            elif j.right_id in chains:
+                # only the right block is chained: grow that chain in place by the left block
+                # NOTE(review): by symmetry with the branch above, the first test looks like it
+                # should be rev_blk(j.right_blk) == c[-1]; as written it appends the flipped
+                # left block after an unflipped right block — confirm.
+                # NOTE(review): neither this branch nor the one above records the newly
+                # added block in `chains` — confirm this is intended.
+                c = chains[j.right_id]
+                if j.right_blk == c[-1]:
+                    c.append(rev_blk(j.left_blk))
+                elif j.right_blk == c[0]:
+                    c.insert(0, j.left_blk)
+                else:
+                    breakpoint("chains should be linear")
+            else:
+                # neither block is chained yet: start a new two-element chain shared by both ids
+                chains[j.left_id] = [j.left_blk, j.right_blk]
+                chains[j.right_id] = chains[j.left_id]
+
+        # several ids share one list object; keying on id() keeps each chain once
+        chains = list({id(c):c for c in chains.values()}.values())
+        for c in chains:
+            # blocks listed on the minus strand are reverse-complemented before concatenation
+            new_blk = Block.cat([self.blks[b] if s == Strand.Plus else self.blks[b].rev_cmpl() for b, s in c])
+            # TODO: check that isos is constant along the chain
+            for iso in self.blks[c[0][0]].isolates.keys():
+                self.seqs[iso].merge(c[0], c[-1], new_blk)
+                # for n in self.seqs[iso].nodes:
+                #     if n.blk.id in [e[0] for e in c]:
+                #         breakpoint("bad deletion")
+
+            # register the fused block, then drop the blocks it replaces
+            self.blks[new_blk.id] = new_blk
+            for b, _ in c:
+                self.blks.pop(b)
+
def prune_blks(self):
blks = set()
for path in self.seqs.values():
diff --git a/pangraph/sequence.py b/pangraph/sequence.py
index 6860c1f..b2b5ce7 100644
--- a/pangraph/sequence.py
+++ b/pangraph/sequence.py
@@ -14,6 +14,18 @@ class Node(object):
self.num = num
self.strand = strand
+    def __str__(self):
+        """Format the node as "(blk, num, strand)"."""
+        text = f"({self.blk}, {self.num}, {self.strand})"
+        return text
+
+    def __repr__(self):
+        """Use the same text as `__str__` for the debugging representation."""
+        return self.__str__()
+
+    def __eq__(self, other):
+        """Two nodes are equal when block id and strand both match; num is not compared."""
+        same_blk = self.blk.id == other.blk.id
+        return same_blk and self.strand == other.strand
+
+    def __hash__(self):
+        """Hash on (block id, strand), the same fields `__eq__` compares."""
+        key = (self.blk.id, self.strand)
+        return hash(key)
+
@classmethod
def from_dict(cls, d, blks):
N = Node()
@@ -41,6 +53,12 @@ class Path(object):
self.offset = offset
self.position = np.cumsum([0] + [n.length(name) for n in self.nodes])
+    def __str__(self):
+        """Format the path as "<name>: [<node strings>]"."""
+        labels = list(map(str, self.nodes))
+        return f"{self.name}: {labels}"
+
+    def __repr__(self):
+        """Use the same text as `__str__` for the debugging representation."""
+        return self.__str__()
+
@classmethod
def from_dict(cls, d):
P = Path()
@@ -59,7 +77,7 @@ class Path(object):
def sequence(self, verbose=False):
seq = ""
for n in self.nodes:
- s = n.blk.extract(self.name, n.num, strip_gaps=False, verbose=verbose)
+ s = n.blk.extract(self.name, n.num, strip_gaps=True, verbose=verbose)
if n.strand == Strand.Plus:
seq += s
else:
@@ -95,6 +113,30 @@ class Path(object):
self.nodes = [self.nodes[i] for i in good]
self.position = np.cumsum([0] + [n.length(self.name) for n in self.nodes])
+    # TODO: debug cases w/ multiple runs
+    def merge(self, start, stop, new):
+        """Collapse each run of nodes from block `start` to block `stop` into one node of `new`.
+
+        start, stop: (block id, strand) pairs naming the first and last block of the chain.
+        new:         the block that replaces the run.
+
+        The node list is rescanned after every replacement; the k-th run found
+        becomes Node(new, k, strand). Scanning ends as soon as either end block
+        can no longer be found on the path. `nodes` and `position` are updated
+        in place, and `offset` as well when a run wraps past the end of the list.
+        """
+        N = 0
+        while True:
+            ids = [n.blk.id for n in self.nodes]
+            # Only the lookup is guarded: a missing end block is the normal way out.
+            # Any other failure is a real error and must not be silenced.
+            try:
+                i, j = ids.index(start[0]), ids.index(stop[0])
+            except ValueError:
+                return
+
+            # if the start block sits on the other strand, the run is laid out stop..start
+            if self.nodes[i].strand == start[1]:
+                beg, end, s = i, j, Strand.Plus
+            else:
+                beg, end, s = j, i, Strand.Minus
+
+            if beg < end:
+                self.nodes = self.nodes[:beg] + [Node(new, N, s)] + self.nodes[end+1:]
+            else:
+                # the run wraps past the end of the node list: the tail nodes[beg:] and the
+                # head nodes[:end+1] are replaced by one node placed at the front, and the
+                # offset grows by the summed len_of of the tail nodes
+                # NOTE(review): len_of is queried with the run counter N rather than each
+                # node's own num — confirm this is intended
+                self.offset += sum(n.blk.len_of(self.name, N) for n in self.nodes[beg:])
+                self.nodes = [Node(new, N, s)] + self.nodes[end+1:beg]
+            self.position = np.cumsum([0] + [n.length(self.name) for n in self.nodes])
+
+            N += 1
+
def replace(self, blk, tag, new_blks, blk_map):
new = []
for n in self.nodes:
diff --git a/pangraph/tree.py b/pangraph/tree.py
index ba93fe8..605bac3 100644
--- a/pangraph/tree.py
+++ b/pangraph/tree.py
@@ -304,6 +304,7 @@ class Tree(object):
rec = G.extract(n.name)
uncompressed_length += len(orig)
if orig != rec:
+ breakpoint("inconsistency")
nerror += 1
with open("test.fa", "w+") as out:
@@ -320,15 +321,15 @@ class Tree(object):
pos = [0]
seq = G.seqs[n.name]
for nn in seq.nodes:
- pos.append(pos[-1] + len(G.blks[nn.id].extract(n.name, nn.num)))
+ pos.append(pos[-1] + len(G.blks[nn.blk.id].extract(n.name, nn.num)))
pos = pos[1:]
testseqs = []
for nn in G.seqs[n.name].nodes:
if nn.strand == Strand.Plus:
- testseqs.append("".join(G.blks[nn.id].extract(n.name, nn.num)))
+ testseqs.append("".join(G.blks[nn.blk.id].extract(n.name, nn.num)))
else:
- testseqs.append("".join(rev_cmpl(G.blks[nn.id].extract(n.name, nn.num))))
+ testseqs.append("".join(rev_cmpl(G.blks[nn.blk.id].extract(n.name, nn.num))))
if nerror == 0:
log("all sequences correctly reconstructed")