aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-11 14:40:40 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-11 14:40:40 -0700
commit45851d361a6dd81400cca97866526d2dabdc32dd (patch)
tree9a5d70a0ca206e9583f7ec5fa3c32e42bb8a8479
parentfd50eff732d45ba2c4400fabf301c65c776f2338 (diff)
refactor: changed search on expr nodes to a filter function
-rw-r--r--lib/c.py104
1 files changed, 62 insertions, 42 deletions
diff --git a/lib/c.py b/lib/c.py
index 518952c..b67ea98 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -5,7 +5,7 @@ import os
import sys
from enum import Enum
-from typing import List, Tuple, Dict
+from typing import Set, List, Tuple, Dict
from functools import singledispatch
numel = len
@@ -700,7 +700,7 @@ def _(ix: Index, i: int):
@singledispatch
def Visit(s: Stmt, func):
- raise TypeError(f"{type(s)} not supported by VarsAccessed operation")
+ raise TypeError(f"{type(s)} not supported by Visit operation")
@Visit.register
def _(s: Empty, func):
@@ -730,45 +730,67 @@ def _(x: StmtExpr, func):
# expression functions
@singledispatch
-def VarsAccessed(x: Expr, vars: List[Vars]):
- raise TypeError(f"{type(s)} not supported by VarsAccessed operation")
-
-@VarsAccessed.register(Empty)
-@VarsAccessed.register(Literal)
-def _(x, vars):
+def Filter(x: Expr, kind: type, results: List[Expr]):
+ raise TypeError(f"{type(s)} not supported by Filter operation")
+
+@Filter.register
+def _(x: Empty, kind: type, results: List[Expr]):
+ if isinstance(x, kind):
+ results.append(x)
+
+@Filter.register(Ident)
+@Filter.register(Literal)
+def _(x, kind:type, results: List[Expr]):
+ if isinstance(x, kind):
+ results.append(x)
return
-@VarsAccessed.register
-def _(op: UnaryOp, vars):
- return VarsAccessed(op.x, vars)
-
-@VarsAccessed.register
-def _(sym: Ident, vars):
- vars.append(sym.var)
-
-@VarsAccessed.register
-def _(s: Assign, vars):
- VarsAccessed(s.lhs, vars)
- VarsAccessed(s.rhs, vars)
-
-@VarsAccessed.register
-def _(v: Deref, vars):
- VarsAccessed(v.x, vars)
-
-@VarsAccessed.register
-def _(i: Index, vars):
- VarsAccessed(i.x, vars)
- VarsAccessed(i.i, vars)
-
-@VarsAccessed.register
-def _(comma: Comma, vars):
- VarsAccessed(comma.expr[0], vars)
- VarsAccessed(comma.expr[1], vars)
+@Filter.register
+def _(x: UnaryOp, kind:type, results: List[Expr]):
+ if isinstance(x, kind):
+ results.append(x)
+ Filter(op.x, kind, results)
+
+@Filter.register
+def _(s: Assign, kind:type, results: List[Expr]):
+ if isinstance(s, kind):
+ results.append(s)
+ Filter(s.lhs, kind, results)
+ Filter(s.rhs, kind, results)
+
+@Filter.register
+def _(v: Deref, kind:type, results: List[Expr]):
+ if isinstance(v, kind):
+ results.append(v)
+ Filter(v.x, kind, results)
+
+@Filter.register
+def _(i: Index, kind: type, results: List[Expr]):
+ if isinstance(i, kind):
+ results.append(i)
+ Filter(i.x, kind, results)
+ Filter(i.i, kind, results)
+
+@Filter.register
+def _(comma: Comma, kind: type, results: List[Expr]):
+ if isinstance(comma, kind):
+ results.append(comma)
+ Filter(comma.expr[0], kind, results)
+ Filter(comma.expr[1], kind, results)
+
+@Filter.register
+def _(op: BinaryOp, kind: type, results: List[Expr]):
+ if isinstance(op, kind):
+ results.append(op)
+ Filter(op.l, kind, results)
+ Filter(op.r, kind, results)
+
+def VarsUsed(stmt: Stmt) -> Set[Var]:
+ vars = []
+ Visit(loop.body, lambda node: Filter(node, Ident, vars))
+ vars = set([v.var for v in vars])
-@VarsAccessed.register
-def _(op: BinaryOp, vars):
- VarsAccessed(op.l, vars)
- VarsAccessed(op.r, vars)
+ return vars
# ------------------------------------------
# Large scale functions
@@ -780,9 +802,7 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
# pull off needed features of the loop
it = loop.init.lhs.var
- vars = []
- Visit(loop.body, lambda node: VarsAccessed(node, vars))
- print(vars)
+ vars = VarsUsed(loop.body)
params = [v for v in vars if type(v) == Param]
stacks = [v for v in vars if type(v) == Var]
@@ -813,7 +833,7 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
if isa != SIMD.AVX2:
raise ValueError(f"ISA '{isa}' not currently implemented")
- vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params, func.vars)
+ vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params)
for stmt in func.stmts:
if IsLoop(stmt):
loop = For(stmt.init, stmt.cond, stmt.step)