From 4e4a3ae1b19611f0367624f541efe91aff570fab Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Tue, 12 May 2020 11:12:58 -0700 Subject: fix: some type errors on the AST --- lib/c.py | 62 +++++++++++++++++++++++++++++++++++++------------------------- 1 file changed, 37 insertions(+), 25 deletions(-) (limited to 'lib') diff --git a/lib/c.py b/lib/c.py index c963b5b..d88f7ab 100644 --- a/lib/c.py +++ b/lib/c.py @@ -538,20 +538,16 @@ class Stmt(Emitter): emit(";") class Empty(Stmt): - def __init__(self): - pass - def emit(self): - super(Empty, self).emit() + pass class Block(Stmt): - def __init__(self, *stmts: List[Stmt]): - self.stmts = list(stmts) + def __init__(self, *stmts: List[Stmt | Expr]): + self.stmts = [ s if isinstance(s, Stmt) else StmtExpr(s) for s in stmts ] def emit(self): enter_scope() for i, stmt in enumerate(self.stmts): stmt.emit() - emit(";") if i < numel(self.stmts) - 1: emitln() @@ -605,7 +601,7 @@ class For(Stmt): self.body = stmt return elif not isinstance(self.body, Block): - self.body = Block([self.body]) + self.body = Block(self.body) self.body.stmts.append(stmt) class Return(Stmt): @@ -764,15 +760,21 @@ def Swap(x: Var, y: Var, tmp: Var) -> Stmt: def IsLoop(s: Stmt) -> bool: return type(s) == For -def Expand(s: StmtExpr, times: int) -> Block: - if not isinstance(s, StmtExpr): - raise TypeError(f"{type(x)} not supported by Expand operation") - - return Block(*[StmtExpr(Step(s.x, i)) for i in range(times)]) - def EvenTo(x: Var, n: int) -> Var: return And(x, Negate(I(n-1))) +# ------------------------------------------ +# Expand: takes statement, indexed, and expands it x times + +# def Expand(s: Stmt, times: int) -> Block: +# if not isinstance(s, StmtExpr): +# raise TypeError(f"{type(x)} not supported by Expand operation") + +# return Block(StmtExpr(Step(s.x, i)) for i in range(times)) + +# ------------------------------------------ +# repeat: takes command on an array and repeats it + @singledispatch def Repeat(x: Expr, times: int) -> Expr | List[Expr]: raise TypeError(f"{type(x)} not supported by Repeat operation") @@ -805,7 +807,9 @@ def _(x: Dec, times: int): def _(x: Comma, times: int): return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times)) -# TODO: Parameterize the variables that should not be expanded +# ------------------------------------------ +# step: indexes an expression by i + @singledispatch def Step(x: Expr, i: int) -> Expr: raise TypeError(f"{type(x)} not supported by Step operation") @@ -852,21 +856,21 @@ def _(s: Empty, func): @Visit.register def _(blk: Block, func): for stmt in blk.stmts: - func(stmt) + Visit(stmt, func) @Visit.register def _(jmp: If, func): func(jmp.cond) - func(jmp.then) + Visit(jmp.then, func) if jmp.orelse is not None: - func(jmp.orelse) + Visit(jmp.orelse, func) @Visit.register def _(loop: For, func): func(loop.init) func(loop.cond) func(loop.step) - func(loop.body) + Visit(loop.body, func) @Visit.register def _(ret: Return, func): @@ -889,7 +893,7 @@ def _(s: Empty, func): @Make.register def _(blk: Block, func): - return Block(*[func(stmt) for stmt in blk.Stmts]) + return Block(*(Make(stmt, func) for stmt in blk.stmts)) @Make.register def _(loop: For, func): @@ -897,7 +901,7 @@ def _(loop: For, func): @Make.register def _(jmp: If, func): - return If(func(jmp.cond), func(jmp.then), func(loop.orelse) if loop.orelse is not None else None) + return If(func(jmp.cond), Make(jmp.then, func), Make(jmp.orelse, func) if jmp.orelse is not None else None) @Make.register def _(ret: Return, func): @@ -1204,7 +1208,7 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun kernel.execute( Set(n, EvenTo(n, times)), For(Set(i, I(0)), LT(i, n), Repeat(loop.step, times), - body=Expand(loop.body, times) + body = Block(*(Make(loop.body, lambda x: Step(x, i)) for i in range(times))) ), Return(n) ) @@ -1347,9 +1351,7 @@ def axpby(): F.emit() print(buffer) -if __name__ == "__main__": - emitheader() - +def argmax(): F = Func("blas·argmax", Int, Params( (Int, "len"), (Ptr(Float64), "x"), @@ -1370,7 +1372,17 @@ if __name__ == "__main__": ) ) + loop, kernel, call = Unroll(loop, 8, "argmax") + kernel.emit() + emitln(2) + F.execute(loop) F.emit() print(buffer) + +if __name__ == "__main__": + emitheader() + + argmax() + -- cgit v1.2.1