From 45851d361a6dd81400cca97866526d2dabdc32dd Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Mon, 11 May 2020 14:40:40 -0700 Subject: refactor: changed search on expr nodes to a filter function --- lib/c.py | 104 +++++++++++++++++++++++++++++++++++++-------------------------- 1 file changed, 62 insertions(+), 42 deletions(-) (limited to 'lib') diff --git a/lib/c.py b/lib/c.py index 518952c..b67ea98 100644 --- a/lib/c.py +++ b/lib/c.py @@ -5,7 +5,7 @@ import os import sys from enum import Enum -from typing import List, Tuple, Dict +from typing import Set, List, Tuple, Dict from functools import singledispatch numel = len @@ -700,7 +700,7 @@ def _(ix: Index, i: int): @singledispatch def Visit(s: Stmt, func): - raise TypeError(f"{type(s)} not supported by VarsAccessed operation") + raise TypeError(f"{type(s)} not supported by Visit operation") @Visit.register def _(s: Empty, func): @@ -730,45 +730,67 @@ def _(x: StmtExpr, func): # expression functions @singledispatch -def VarsAccessed(x: Expr, vars: List[Vars]): - raise TypeError(f"{type(s)} not supported by VarsAccessed operation") - -@VarsAccessed.register(Empty) -@VarsAccessed.register(Literal) -def _(x, vars): +def Filter(x: Expr, kind: type, results: List[Expr]): + raise TypeError(f"{type(s)} not supported by Filter operation") + +@Filter.register +def _(x: Empty, kind: type, results: List[Expr]): + if isinstance(x, kind): + results.append(x) + +@Filter.register(Ident) +@Filter.register(Literal) +def _(x, kind:type, results: List[Expr]): + if isinstance(x, kind): + results.append(x) return -@VarsAccessed.register -def _(op: UnaryOp, vars): - return VarsAccessed(op.x, vars) - -@VarsAccessed.register -def _(sym: Ident, vars): - vars.append(sym.var) - -@VarsAccessed.register -def _(s: Assign, vars): - VarsAccessed(s.lhs, vars) - VarsAccessed(s.rhs, vars) - -@VarsAccessed.register -def _(v: Deref, vars): - VarsAccessed(v.x, vars) - -@VarsAccessed.register -def _(i: Index, vars): - VarsAccessed(i.x, vars) - VarsAccessed(i.i, vars) - -@VarsAccessed.register -def _(comma: Comma, vars): - VarsAccessed(comma.expr[0], vars) - VarsAccessed(comma.expr[1], vars) +@Filter.register +def _(x: UnaryOp, kind:type, results: List[Expr]): + if isinstance(x, kind): + results.append(x) + Filter(op.x, kind, results) + +@Filter.register +def _(s: Assign, kind:type, results: List[Expr]): + if isinstance(s, kind): + results.append(s) + Filter(s.lhs, kind, results) + Filter(s.rhs, kind, results) + +@Filter.register +def _(v: Deref, kind:type, results: List[Expr]): + if isinstance(v, kind): + results.append(v) + Filter(v.x, kind, results) + +@Filter.register +def _(i: Index, kind: type, results: List[Expr]): + if isinstance(i, kind): + results.append(i) + Filter(i.x, kind, results) + Filter(i.i, kind, results) + +@Filter.register +def _(comma: Comma, kind: type, results: List[Expr]): + if isinstance(comma, kind): + results.append(comma) + Filter(comma.expr[0], kind, results) + Filter(comma.expr[1], kind, results) + +@Filter.register +def _(op: BinaryOp, kind: type, results: List[Expr]): + if isinstance(op, kind): + results.append(op) + Filter(op.l, kind, results) + Filter(op.r, kind, results) + +def VarsUsed(stmt: Stmt) -> Set[Var]: + vars = [] + Visit(loop.body, lambda node: Filter(node, Ident, vars)) + vars = set([v.var for v in vars]) -@VarsAccessed.register -def _(op: BinaryOp, vars): - VarsAccessed(op.l, vars) - VarsAccessed(op.r, vars) + return vars # ------------------------------------------ # Large scale functions @@ -780,9 +802,7 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun # pull off needed features of the loop it = loop.init.lhs.var - vars = [] - Visit(loop.body, lambda node: VarsAccessed(node, vars)) - print(vars) + vars = VarsUsed(loop.body) params = [v for v in vars if type(v) == Param] stacks = [v for v in vars if type(v) == Var] @@ -813,7 +833,7 @@ def Vectorize(func: Func, isa: SIMD) -> Func: if isa != SIMD.AVX2: raise ValueError(f"ISA '{isa}' not currently implemented") - vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params, func.vars) + vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params) for stmt in func.stmts: if IsLoop(stmt): loop = For(stmt.init, stmt.cond, stmt.step) -- cgit v1.2.1