diff options
author | Nicholas Noll <nbnoll@eml.cc> | 2020-08-12 12:01:33 -0700 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-08-12 12:01:33 -0700 |
commit | a5f38a5350fbc876ddfb9a8fb041d3937083c4c9 (patch) | |
tree | 076c7f17cecfd83acdd2641784b052e5a09b4470 | |
parent | 7ae833aff30dfcd0ef90dcbd528dd15a086b5838 (diff) | |
parent | 8eb9feb037ef53e5c498335f22437195da5d437c (diff) |
Merge pull request #3 from nnoll/feat/merge-transitives
Track repeated junctions to merge transitives
-rw-r--r-- | Makefile | 2 | ||||
-rw-r--r-- | pangraph/block.py | 12 | ||||
-rw-r--r-- | pangraph/build.py | 9 | ||||
-rw-r--r-- | pangraph/graph.py | 153 | ||||
-rw-r--r-- | pangraph/sequence.py | 44 | ||||
-rw-r--r-- | pangraph/tree.py | 7 |
6 files changed, 204 insertions, 23 deletions
@@ -55,7 +55,7 @@ staph: @echo "cluster staph"; \ pangraph cluster -d data/staph data/staph/assemblies/*.fna.gz @echo "build staph"; \ - pangraph build -d data/staph -m 500 -b 0 data/staph/guide.json 1>data/staph/pangraph.json + pangraph build -d data/staph -m 500 -b 0 data/staph/guide.json # 1>data/staph/pangraph.json # figures diff --git a/pangraph/block.py b/pangraph/block.py index 4c438c4..05b7806 100644 --- a/pangraph/block.py +++ b/pangraph/block.py @@ -1,7 +1,7 @@ import numpy as np import numpy.random as rng -from collections import defaultdict +from collections import defaultdict, Counter from .utils import parse_cigar, wcpair, as_array, as_string # ------------------------------------------------------------------------ @@ -28,6 +28,12 @@ class Block(object): self.seq = None self.muts = {} + def __str__(self): + return str(self.id) + + def __repr__(self): + return str(self) + # ------------------ # properties @@ -41,7 +47,7 @@ class Block(object): @property def isolates(self): - return list(set([ k[0] for k in self.muts.keys() ])) + return dict(Counter([k[0] for k in self.muts])) # ------------------ # static methods @@ -69,7 +75,7 @@ class Block(object): @classmethod def cat(cls, blks): - nblk = cls() + nblk = Block() assert all([blks[0].muts.keys() == b2.muts.keys() for b2 in blks[1:]]) nblk.seq = np.concatenate([b.seq for b in blks]) diff --git a/pangraph/build.py b/pangraph/build.py index 14f96d3..01595b7 100644 --- a/pangraph/build.py +++ b/pangraph/build.py @@ -3,7 +3,6 @@ build a pangenome alignment from an annotated guide tree """ import os, sys import builtins -import cProfile from .utils import mkdir, log from .tree import Tree @@ -71,12 +70,10 @@ def main(args): mkdir(tmp) log("aligning") - with cProfile.Profile() as pr: - T.align(tmp, args.len, args.mu, args.beta, args.extensive, args.statistics) - # TODO: when debugging phase is done, remove tmp directory + T.align(tmp, args.len, args.mu, args.beta, args.extensive, args.statistics) + # TODO: when debugging phase is done, remove tmp directory - graphs = T.collect() - pr.dump_stats("perf.prof") + graphs = T.collect() for i, g in enumerate(graphs): log(f"graph {i}: nseqs: {len(g.seqs)} nblks: {len(g.blks)}") diff --git a/pangraph/graph.py b/pangraph/graph.py index 6589443..bc1e9d9 100644 --- a/pangraph/graph.py +++ b/pangraph/graph.py @@ -1,8 +1,10 @@ import os, sys import json import numpy as np +import pprint -from glob import glob +from glob import glob +from collections import defaultdict, Counter from Bio import SeqIO, Phylo from Bio.Seq import Seq @@ -17,6 +19,62 @@ from .utils import Strand, as_string, parse_paf, panic, as_record, new_strand # globals EXTEND = 2500 +pp = pprint.PrettyPrinter(indent=4) + +# ------------------------------------------------------------------------ +# Junction class +# simple struct + +class Junction(object): + def __init__(self, left, right): + self.left = left + self.right = right + + def __eq__(self, other): + if self.data == other.data: + return True + elif self.data == other.reverse().data: + return False + else: + return False + + def __hash__(self): + return hash(frozenset([self.data, self.reverse().data])) + + def __str__(self): + return f"({self.left}, {self.right})" + + def __repr__(self): + return str(self) + + @property + def data(self): + return ((self.left.blk.id, self.left.strand), (self.right.blk.id, self.right.strand)) + + @property + def right_id(self): + return self.right.blk.id + + @property + def left_id(self): + return self.left.blk.id + + @property + def left_blk(self): + return (self.left.blk.id, self.left.strand) + + @property + def right_blk(self): + return (self.right.blk.id, self.right.strand) + + def reverse(self): + return Junction( + Node(self.right.blk, self.right.num, Strand(-1*self.right.strand)), + Node(self.left.blk, self.left.num, Strand(-1*self.left.strand)), + ) + +def rev_blk(b): + return (b[0], Strand(-1*b[1])) # ------------------------------------------------------------------------ # Graph class @@ -266,20 +324,97 @@ class Graph(object): merged_blks.add(hit['ref']['name']) merged_blks.add(hit['qry']['name']) - for blk in new_blks: - for iso in blk.isolates: - path = self.seqs[iso] - x, n = path.position_of(blk) - lb, ub = max(0, x-EXTEND), min(x+blk.len_of(iso, n)+EXTEND, len(path)) - subpath = path[lb:ub] - print(subpath, file=sys.stderr) - breakpoint("stop") + # for blk in new_blks: + # for iso in blk.isolates: + # path = self.seqs[iso] + # x, n = path.position_of(blk) + # lb, ub = max(0, x-EXTEND), min(x+blk.len_of(iso, n)+EXTEND, len(path)) + # subpath = path[lb:ub] + # print(subpath, file=sys.stderr) + # breakpoint("stop") + self.remove_transitives() for path in self.seqs.values(): path.rm_nil_blks() return self, merged + # a junction is a pair of adjacent blocks. + def junctions(self): + junctions = defaultdict(list) + for iso, path in self.seqs.items(): + if len(path.nodes) == 1: + continue + + for i, n in enumerate(path.nodes): + j = Junction(path.nodes[i-1], n) + junctions[j].append(iso) + return { k:dict(Counter(v)) for k, v in junctions.items() } + + def remove_transitives(self): + js = self.junctions() + transitives = [] + for j, isos in js.items(): + left_eq_right = self.blks[j.left.blk.id].isolates == self.blks[j.right.blk.id].isolates + left_eq_isos = isos == self.blks[j.left.blk.id].isolates + if left_eq_right and left_eq_isos: + transitives.append(j) + + chains = {} + for j in transitives: + if j.left_id in chains and j.right_id in chains: + c1, c2 = chains[j.left_id], chains[j.right_id] + if c1 == c2: + continue + + if j.left_blk==c1[-1] and j.right_blk==c2[0]: + new_chain = c1 + c2 + elif j.left_blk==c1[-1] and rev_blk(j.right_blk)==c2[-1]: + new_chain = c1 + [rev_blk(b) for b in c2[::-1]] + elif rev_blk(j.left_blk)==c1[0] and j.right_blk==c2[0]: + new_chain = [rev_blk(b) for b in c1[::-1]] + c2 + elif rev_blk(j.left_blk)==c1[0] and rev_blk(j.right_blk)==c2[-1]: + new_chain = c2 + c1 + else: + breakpoint("case not covered") + + for b, _ in new_chain: + chains[b] = new_chain + + elif j.left_id in chains: + c = chains[j.left_id] + if j.left_blk == c[-1]: + c.append(j.right_blk) + elif rev_blk(j.left_blk) == c[0]: + c.insert(0, rev_blk(j.right_blk)) + else: + breakpoint("chains should be linear") + elif j.right_id in chains: + c = chains[j.right_id] + if j.right_blk == c[-1]: + c.append(rev_blk(j.left_blk)) + elif j.right_blk == c[0]: + c.insert(0, j.left_blk) + else: + breakpoint("chains should be linear") + else: + chains[j.left_id] = [j.left_blk, j.right_blk] + chains[j.right_id] = chains[j.left_id] + + chains = list({id(c):c for c in chains.values()}.values()) + for c in chains: + new_blk = Block.cat([self.blks[b] if s == Strand.Plus else self.blks[b].rev_cmpl() for b, s in c]) + # TODO: check that isos is constant along the chain + for iso in self.blks[c[0][0]].isolates.keys(): + self.seqs[iso].merge(c[0], c[-1], new_blk) + # for n in self.seqs[iso].nodes: + # if n.blk.id in [e[0] for e in c]: + # breakpoint("bad deletion") + + self.blks[new_blk.id] = new_blk + for b, _ in c: + self.blks.pop(b) + def prune_blks(self): blks = set() for path in self.seqs.values(): diff --git a/pangraph/sequence.py b/pangraph/sequence.py index 6860c1f..b2b5ce7 100644 --- a/pangraph/sequence.py +++ b/pangraph/sequence.py @@ -14,6 +14,18 @@ class Node(object): self.num = num self.strand = strand + def __str__(self): + return f"({self.blk}, {self.num}, {self.strand})" + + def __repr__(self): + return str(self) + + def __eq__(self, other): + return self.blk.id == other.blk.id and self.strand == other.strand + + def __hash__(self): + return hash((self.blk.id, self.strand)) + @classmethod def from_dict(cls, d, blks): N = Node() @@ -41,6 +53,12 @@ class Path(object): self.offset = offset self.position = np.cumsum([0] + [n.length(name) for n in self.nodes]) + def __str__(self): + return f"{self.name}: {[str(n) for n in self.nodes]}" + + def __repr__(self): + return str(self) + @classmethod def from_dict(cls, d): P = Path() @@ -59,7 +77,7 @@ class Path(object): def sequence(self, verbose=False): seq = "" for n in self.nodes: - s = n.blk.extract(self.name, n.num, strip_gaps=False, verbose=verbose) + s = n.blk.extract(self.name, n.num, strip_gaps=True, verbose=verbose) if n.strand == Strand.Plus: seq += s else: @@ -95,6 +113,30 @@ class Path(object): self.nodes = [self.nodes[i] for i in good] self.position = np.cumsum([0] + [n.length(self.name) for n in self.nodes]) + # TODO: debug cases w/ multiple runs + def merge(self, start, stop, new): + N = 0 + while True: + ids = [n.blk.id for n in self.nodes] + try: + i, j = ids.index(start[0]), ids.index(stop[0]) + + if self.nodes[i].strand == start[1]: + beg, end, s = i, j, Strand.Plus + else: + beg, end, s = j, i, Strand.Minus + + if beg < end: + self.nodes = self.nodes[:beg] + [Node(new, N, s)] + self.nodes[end+1:] + else: + self.offset += sum(n.blk.len_of(self.name, N) for n in self.nodes[beg:]) + self.nodes = [Node(new, N, s)] + self.nodes[end+1:beg] + self.position = np.cumsum([0] + [n.length(self.name) for n in self.nodes]) + + N += 1 + except: + return + def replace(self, blk, tag, new_blks, blk_map): new = [] for n in self.nodes: diff --git a/pangraph/tree.py b/pangraph/tree.py index ba93fe8..605bac3 100644 --- a/pangraph/tree.py +++ b/pangraph/tree.py @@ -304,6 +304,7 @@ class Tree(object): rec = G.extract(n.name) uncompressed_length += len(orig) if orig != rec: + breakpoint("inconsistency") nerror += 1 with open("test.fa", "w+") as out: @@ -320,15 +321,15 @@ class Tree(object): pos = [0] seq = G.seqs[n.name] for nn in seq.nodes: - pos.append(pos[-1] + len(G.blks[nn.id].extract(n.name, nn.num))) + pos.append(pos[-1] + len(G.blks[nn.blk.id].extract(n.name, nn.num))) pos = pos[1:] testseqs = [] for nn in G.seqs[n.name].nodes: if nn.strand == Strand.Plus: - testseqs.append("".join(G.blks[nn.id].extract(n.name, nn.num))) + testseqs.append("".join(G.blks[nn.blk.id].extract(n.name, nn.num))) else: - testseqs.append("".join(rev_cmpl(G.blks[nn.id].extract(n.name, nn.num)))) + testseqs.append("".join(rev_cmpl(G.blks[nn.blk.id].extract(n.name, nn.num)))) if nerror == 0: log("all sequences correctly reconstructed") |