From fd50eff732d45ba2c4400fabf301c65c776f2338 Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Mon, 11 May 2020 14:17:11 -0700 Subject: refactor: pulled out statement BFS code from expr eval code --- lib/c.py | 194 ++++++++++++++++++++++++++++++++++++++------------------------- 1 file changed, 118 insertions(+), 76 deletions(-) (limited to 'lib/c.py') diff --git a/lib/c.py b/lib/c.py index 1554cc4..518952c 100644 --- a/lib/c.py +++ b/lib/c.py @@ -620,85 +620,39 @@ def Vars(*ps: List[Tuple(Type, str)]) -> List[Var]: # ------------------------------------------------------------------------ # AST modification/production functions +# ------------------------------------------ +# basic (non-recursive) commands + def Swap(x: Var, y: Var, tmp: Var) -> Stmt: return StmtExpr(ExprList(Set(tmp, x), Set(x, y), Set(y, tmp))) def IsLoop(s: Stmt) -> bool: return type(s) == For -# TODO: Generalize to memory addresses! -# This will allow better vectorization -@singledispatch -def VarsUsed(s: object) -> List[Vars]: - raise TypeError(f"{type(s)} not supported by VarsUsed operation") - -@VarsUsed.register -def _(op: UnaryOp): - return VarsUsed(op.x) - -@VarsUsed.register -def _(sym: Ident): - return [sym.var] - -@VarsUsed.register -def _(s: Assign): - vars = [] - vars.extend(VarsUsed(s.lhs)) - vars.extend(VarsUsed(s.rhs)) - return vars - -@VarsUsed.register -def _(lit: Literal): - return [] - -@VarsUsed.register -def _(i: Index): - vars = [] - vars.extend(VarsUsed(i.x)) - vars.extend(VarsUsed(i.i)) - return vars - -@VarsUsed.register -def _(comma: Comma): - vars = [] - vars.extend(VarsUsed(comma.expr[0])) - vars.extend(VarsUsed(comma.expr[1])) - return vars +def Expand(s: StmtExpr, times: int) -> Block: + if not isinstance(s, StmtExpr): + raise TypeError(f"{type(x)} not supported by Expand operation") -@VarsUsed.register -def _(op: BinaryOp): - vars = [] - vars.extend(VarsUsed(op.l)) - vars.extend(VarsUsed(op.r)) - return vars + return Block([StmtExpr(Step(s.x, i)) for i in range(times)]) -@VarsUsed.register -def _(s: Empty): - return [] +def EvenTo(x: Var, n: int) -> Var: + return And(x, Negate(I(n-1))) -@VarsUsed.register -def _(blk: Block): - vars = [] - for stmt in blk.stmts: - vars.extend(VarsUsed(stmt)) - return vars +@singledispatch +def Repeat(x: Expr, times: int) -> Expr | List[Expr]: + raise TypeError(f"{type(x)} not supported by Repeat operation") -@VarsUsed.register -def _(loop: For): - vars = [] - vars.extend(VarsUsed(loop.init)) - vars.extend(VarsUsed(loop.cond)) - vars.extend(VarsUsed(loop.step)) - vars.extend(VarsUsed(loop.body)) - return vars +@Repeat.register +def _(x: Inc, times: int): + return AddSet(x.x, I(times)) -@VarsUsed.register -def _(ret: Return): - return VarsUsed(ret.val) +@Repeat.register +def _(x: Dec, times: int): + return DecSet(x.x, I(times)) -@VarsUsed.register -def _(x: StmtExpr): - return VarsUsed(x.x) +@Repeat.register +def _(x: Comma, times: int): + return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times)) @singledispatch def Repeat(x: Expr, times: int) -> Expr | List[Expr]: @@ -741,14 +695,83 @@ def _(x: Deref, i: int): def _(ix: Index, i: int): return Index(ix.x, I(ix.i + i)) -def Expand(s: StmtExpr, times: int) -> Block: - if not isinstance(s, StmtExpr): - raise TypeError(f"{type(x)} not supported by Expand operation") +# ------------------------------------------ +# bfs search on statements in ast - return Block([StmtExpr(Step(s.x, i)) for i in range(times)]) +@singledispatch +def Visit(s: Stmt, func): + raise TypeError(f"{type(s)} not supported by VarsAccessed operation") -def EvenTo(x: Var, n: int) -> Var: - return And(x, Negate(I(n-1))) +@Visit.register +def _(s: Empty, func): + return + +@Visit.register +def _(blk: Block, func): + for stmt in blk.stmts: + func(stmt) + +@Visit.register +def _(loop: For, func): + func(loop.init) + func(loop.cond) + func(loop.step) + func(loop.body) + +@Visit.register +def _(ret: Return, func): + func(ret.val) + +@Visit.register +def _(x: StmtExpr, func): + func(x.x) + +# ------------------------------------------ +# expression functions + +@singledispatch +def VarsAccessed(x: Expr, vars: List[Vars]): + raise TypeError(f"{type(s)} not supported by VarsAccessed operation") + +@VarsAccessed.register(Empty) +@VarsAccessed.register(Literal) +def _(x, vars): + return + +@VarsAccessed.register +def _(op: UnaryOp, vars): + return VarsAccessed(op.x, vars) + +@VarsAccessed.register +def _(sym: Ident, vars): + vars.append(sym.var) + +@VarsAccessed.register +def _(s: Assign, vars): + VarsAccessed(s.lhs, vars) + VarsAccessed(s.rhs, vars) + +@VarsAccessed.register +def _(v: Deref, vars): + VarsAccessed(v.x, vars) + +@VarsAccessed.register +def _(i: Index, vars): + VarsAccessed(i.x, vars) + VarsAccessed(i.i, vars) + +@VarsAccessed.register +def _(comma: Comma, vars): + VarsAccessed(comma.expr[0], vars) + VarsAccessed(comma.expr[1], vars) + +@VarsAccessed.register +def _(op: BinaryOp, vars): + VarsAccessed(op.l, vars) + VarsAccessed(op.r, vars) + +# ------------------------------------------ +# Large scale functions def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Func, Call): # TODO: More sophisticated computation for length of loop @@ -757,7 +780,9 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun # pull off needed features of the loop it = loop.init.lhs.var - vars = set(VarsUsed(loop.body)) + vars = [] + Visit(loop.body, lambda node: VarsAccessed(node, vars)) + print(vars) params = [v for v in vars if type(v) == Param] stacks = [v for v in vars if type(v) == Var] @@ -784,7 +809,7 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun # Replaces all vectorizable loops inside Func with vectorized variants # Returns the new vectorized function (tagged by the SIMD chosen) -def Vectorize(func: Func, arrays: List[Var], isa: SIMD) -> Func: +def Vectorize(func: Func, isa: SIMD) -> Func: if isa != SIMD.AVX2: raise ValueError(f"ISA '{isa}' not currently implemented") @@ -792,12 +817,29 @@ def Vectorize(func: Func, arrays: List[Var], isa: SIMD) -> Func: for stmt in func.stmts: if IsLoop(stmt): loop = For(stmt.init, stmt.cond, stmt.step) + body = stmt.body + if type(body) != Block: + print("could not vectorize loop, skipping") + loop.body = body + else: + instrs = body.stmts + # TODO: Remove hardcoded 4 + if numel(instrs) % 4 != 0: + raise ValueError("loop can not be vectorized, instructions can not be globbed equally") + # TODO: Allow for non-sequential accesses? + # for i in range(numel(instrs)/4): + + vfunc.execute(loop) else: vfunc.execute(stmt) return vfunc + +def Strided(func: Func) -> Func: + pass + # ------------------------------------------------------------------------ # Point of testing @@ -826,7 +868,7 @@ if __name__ == "__main__": kernel.emit() emitln(2) - avx256kernel = Vectorize(kernel, [], SIMD.AVX2) + avx256kernel = Vectorize(kernel, SIMD.AVX2) avx256kernel.emit() emitln(2) -- cgit v1.2.1