diff options
author | Nicholas Noll <nbnoll@eml.cc> | 2020-05-11 14:40:40 -0700 |
---|---|---|
committer | Nicholas Noll <nbnoll@eml.cc> | 2020-05-11 14:40:40 -0700 |
commit | 45851d361a6dd81400cca97866526d2dabdc32dd (patch) | |
tree | 9a5d70a0ca206e9583f7ec5fa3c32e42bb8a8479 /lib | |
parent | fd50eff732d45ba2c4400fabf301c65c776f2338 (diff) |
refactor: changed search on expr nodes to a filter function
Diffstat (limited to 'lib')
-rw-r--r-- | lib/c.py | 104 |
1 files changed, 62 insertions, 42 deletions
@@ -5,7 +5,7 @@ import os import sys from enum import Enum -from typing import List, Tuple, Dict +from typing import Set, List, Tuple, Dict from functools import singledispatch numel = len @@ -700,7 +700,7 @@ def _(ix: Index, i: int): @singledispatch def Visit(s: Stmt, func): - raise TypeError(f"{type(s)} not supported by VarsAccessed operation") + raise TypeError(f"{type(s)} not supported by Visit operation") @Visit.register def _(s: Empty, func): @@ -730,45 +730,67 @@ def _(x: StmtExpr, func): # expression functions @singledispatch -def VarsAccessed(x: Expr, vars: List[Vars]): - raise TypeError(f"{type(s)} not supported by VarsAccessed operation") - -@VarsAccessed.register(Empty) -@VarsAccessed.register(Literal) -def _(x, vars): +def Filter(x: Expr, kind: type, results: List[Expr]): + raise TypeError(f"{type(s)} not supported by Filter operation") + +@Filter.register +def _(x: Empty, kind: type, results: List[Expr]): + if isinstance(x, kind): + results.append(x) + +@Filter.register(Ident) +@Filter.register(Literal) +def _(x, kind:type, results: List[Expr]): + if isinstance(x, kind): + results.append(x) return -@VarsAccessed.register -def _(op: UnaryOp, vars): - return VarsAccessed(op.x, vars) - -@VarsAccessed.register -def _(sym: Ident, vars): - vars.append(sym.var) - -@VarsAccessed.register -def _(s: Assign, vars): - VarsAccessed(s.lhs, vars) - VarsAccessed(s.rhs, vars) - -@VarsAccessed.register -def _(v: Deref, vars): - VarsAccessed(v.x, vars) - -@VarsAccessed.register -def _(i: Index, vars): - VarsAccessed(i.x, vars) - VarsAccessed(i.i, vars) - -@VarsAccessed.register -def _(comma: Comma, vars): - VarsAccessed(comma.expr[0], vars) - VarsAccessed(comma.expr[1], vars) +@Filter.register +def _(x: UnaryOp, kind:type, results: List[Expr]): + if isinstance(x, kind): + results.append(x) + Filter(op.x, kind, results) + +@Filter.register +def _(s: Assign, kind:type, results: List[Expr]): + if isinstance(s, kind): + results.append(s) + Filter(s.lhs, kind, results) + Filter(s.rhs, kind, results) + +@Filter.register +def _(v: Deref, kind:type, results: List[Expr]): + if isinstance(v, kind): + results.append(v) + Filter(v.x, kind, results) + +@Filter.register +def _(i: Index, kind: type, results: List[Expr]): + if isinstance(i, kind): + results.append(i) + Filter(i.x, kind, results) + Filter(i.i, kind, results) + +@Filter.register +def _(comma: Comma, kind: type, results: List[Expr]): + if isinstance(comma, kind): + results.append(comma) + Filter(comma.expr[0], kind, results) + Filter(comma.expr[1], kind, results) + +@Filter.register +def _(op: BinaryOp, kind: type, results: List[Expr]): + if isinstance(op, kind): + results.append(op) + Filter(op.l, kind, results) + Filter(op.r, kind, results) + +def VarsUsed(stmt: Stmt) -> Set[Var]: + vars = [] + Visit(loop.body, lambda node: Filter(node, Ident, vars)) + vars = set([v.var for v in vars]) -@VarsAccessed.register -def _(op: BinaryOp, vars): - VarsAccessed(op.l, vars) - VarsAccessed(op.r, vars) + return vars # ------------------------------------------ # Large scale functions @@ -780,9 +802,7 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun # pull off needed features of the loop it = loop.init.lhs.var - vars = [] - Visit(loop.body, lambda node: VarsAccessed(node, vars)) - print(vars) + vars = VarsUsed(loop.body) params = [v for v in vars if type(v) == Param] stacks = [v for v in vars if type(v) == Var] @@ -813,7 +833,7 @@ def Vectorize(func: Func, isa: SIMD) -> Func: if isa != SIMD.AVX2: raise ValueError(f"ISA '{isa}' not currently implemented") - vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params, func.vars) + vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params) for stmt in func.stmts: if IsLoop(stmt): loop = For(stmt.init, stmt.cond, stmt.step) |