From b4cfd0d257690aebd2df2f6ce00fb2158b3b61d5 Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Mon, 11 May 2020 19:51:14 -0700 Subject: feat: added if statement --- lib/c.py | 112 ++++++++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 78 insertions(+), 34 deletions(-) diff --git a/lib/c.py b/lib/c.py index f1aab4f..0830dc5 100644 --- a/lib/c.py +++ b/lib/c.py @@ -240,8 +240,6 @@ def emit·copy256(l, r): r.emit() emit("))") - - # TODO: Typedefs... # ------------------------------------------ @@ -273,11 +271,18 @@ class Ident(Expr): self.name = var.name self.var = var + def __str__(self): + return str(self.name) + + def __hash__(self): + return hash(self.name) + + def __eq__(self, other): + return type(self) == type(other) and self.name == other.name + def emit(self): emit(f"{self.name}") - def __str__(self): - return str(self.name) # Unary operators class UnaryOp(Expr): @@ -295,7 +300,6 @@ class Deref(UnaryOp): def emit(self): kind = GetType(self.x) if kind in self.method: - print("emitting") self.method[kind](self.x) else: emit("*") @@ -540,18 +544,33 @@ class Empty(Stmt): super(Empty, self).emit() class Block(Stmt): - def __init__(self, stmts: List[Stmt]): - self.stmts = stmts + def __init__(self, *stmts: List[Stmt]): + self.stmts = list(stmts) def emit(self): enter_scope() for i, stmt in enumerate(self.stmts): stmt.emit() + emit(";") if i < numel(self.stmts) - 1: emitln() exits_scope() +class If(Stmt): + def __init__(self, cond: Expr | List[Expr], then: Stmt | List[Stmt], orelse: Stmt | List[Stmt] | None = None): + self.cond = cond if isinstance(cond, Expr) else ExprList(*cond) + self.then = Block(then) if isinstance(then, list) else then + self.orelse = Block(orelse) if isinstance(orelse, list) else orelse + + def emit(self): + emit("if (") + self.cond.emit() + emit(") ") + self.then.emit() + if self.orelse is not None: + self.orelse.emit() + class For(Stmt): def __init__(self, init: Expr | List[Expr], cond: Expr | List[Expr], step: Expr | List[Expr], body: Stmt| None = None): self.init = init if isinstance(init, Expr) else ExprList(*init) @@ -700,6 +719,9 @@ class Func(Decl): def variables(self, *idents: List[str]) -> List[Expr]: vars = {v.name : v for v in self.vars + self.params} + + if numel(idents) == 1: + return Ident(vars[idents[0]]) return [Ident(vars[ident]) for ident in idents] class Var(Decl): @@ -746,7 +768,7 @@ def Expand(s: StmtExpr, times: int) -> Block: if not isinstance(s, StmtExpr): raise TypeError(f"{type(x)} not supported by Expand operation") - return Block([StmtExpr(Step(s.x, i)) for i in range(times)]) + return Block(*[StmtExpr(Step(s.x, i)) for i in range(times)]) def EvenTo(x: Var, n: int) -> Var: return And(x, Negate(I(n-1))) @@ -860,7 +882,7 @@ def _(s: Empty, func): @Make.register def _(blk: Block, func): - return Block([func(stmt) for stmt in blk.Stmts]) + return Block(*[func(stmt) for stmt in blk.Stmts]) @Make.register def _(loop: For, func): @@ -1010,7 +1032,7 @@ def _(x: BinaryOp, func): @singledispatch def Filter(x: Expr, cond, results: List[Expr]): - raise TypeError(f"{type(s)} not supported by Filter operation") + raise TypeError(f"{type(x)} not supported by Filter operation") @Filter.register def _(x: Empty, cond, results: List[Expr]): @@ -1141,7 +1163,7 @@ def AddrAccessed(stmt: Stmt): if vars[0] in scalars: scalars.remove(vars[0]) - return frozenset(scalars), vectors + return set(scalars), vectors # ------------------------------------------ # Large scale functions @@ -1152,7 +1174,7 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun raise TypeError(f"{type(loop.cond)} not supported in loop unrolling") # pull off needed features of the loop - it = loop.init.lhs.var + i = loop.init.lhs.var vars = VarsUsed(loop.body) params = [v for v in vars if type(v) == Param] @@ -1165,11 +1187,12 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun params = [n] + params kernel = Func(f"{name}_kernel{times}", Int, params, stacks) + i = kernel.variables("i") body = loop.body kernel.execute( Set(n, EvenTo(n, times)), - For(Set(it, I(0)), LT(it, n), Repeat(loop.step, times), + For(Set(i, I(0)), LT(i, n), Repeat(loop.step, times), body=Expand(loop.body, times) ), Return(n) @@ -1187,7 +1210,10 @@ def Vectorize(func: Func, isa: SIMD) -> Func: vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params) for stmt in func.stmts: if IsLoop(stmt): - loop = For(stmt.init, stmt.cond, stmt.step) + + loop = For(stmt.init, stmt.cond, stmt.step) + iterator = set([stmt.init.lhs]) + body = stmt.body if type(body) != Block: print("could not vectorize loop, skipping") @@ -1206,7 +1232,8 @@ def Vectorize(func: Func, isa: SIMD) -> Func: vectors = [] for j in range(4): s, v = AddrAccessed(instr[i+j]) - scalars.append(s) + s -= iterator # TODO: This is hacky + scalars.append(frozenset(s)) vectors.append(v) # Test if code in uniform to allow for vectorization @@ -1217,7 +1244,6 @@ def Vectorize(func: Func, isa: SIMD) -> Func: for v, idx in vectors[j].items(): if (delta := (Eval(Sub(idx, vectors[j-1][v])))) != 1: print(f"{delta}") - print(f"{sympy.simplify(delta)}") raise ValueError("non uniform vector accesses in consecutive line. can not vectorize") # If we made it to here, we have passed all checks. vectorize! @@ -1235,12 +1261,6 @@ def Vectorize(func: Func, isa: SIMD) -> Func: for v in vectors[0].keys(): v_symtab[v.name] = Var(Ptr(Float64x4), f"{v.name}256") - # Necessary loads into registers - # NOTE: Generalization of a Deref - # for v in vectors[0].keys(): - # load = SIMDLoadAt(symtab[v.name], Add(v, I(i))) - # loop.execute(load) - # IMPORTANT: We do a post-order traversal. # We transforms leaves (identifiers) first and then move back up the root def translate(x): @@ -1260,8 +1280,6 @@ def Vectorize(func: Func, isa: SIMD) -> Func: if type(r) == Ptr: x.r = Deref(x.r) - # if type(x) == Index and x.x in set(symtab.values()): - # return x.x return x loop.execute(Make(instr[i], lambda node: Transform(node, translate))) @@ -1279,12 +1297,10 @@ def Strided(func: Func) -> Func: # ------------------------------------------------------------------------ # Point of testing -if __name__ == "__main__": - emitheader() - - F = Func("blas·axpy", Void, +def axpby(): + F = Func("blas·axpby", Void, Params( - (Int, "len"), (Float64, "a"), (Ptr(Float64), "x"), (Ptr(Float64), "y") + (Int, "len"), (Float64, "a"), (Ptr(Float64), "x"), (Float64, "b"), (Ptr(Float64), "y") ), Vars( (Int, "i"), #(Float64, "tmp") @@ -1294,19 +1310,21 @@ if __name__ == "__main__": # F.declare( ... ) # TODO: Increase ergonomics here... - len, x, y, i, a = F.variables("len", "x", "y", "i", "a") + len, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b") loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) # body = Swap(Index(x, I(0)), Index(y, I(0)), tmp) loop.execute( Set(Index(y, i), - Add(Index(y, i), - Mul(a, Index(x, i)) + Mul(b, + Add(Index(y, i), + Mul(a, Index(x, i)), + ) ) ) ) - rem, kernel, call = Unroll(loop, 16, "swap") + rem, kernel, call = Unroll(loop, 16, "axpy") kernel.emit() emitln(2) @@ -1315,7 +1333,33 @@ if __name__ == "__main__": emitln(2) F.execute(Set(i, call), rem) - F.emit() + print(buffer) + +if __name__ == "__main__": + emitheader() + + F = Func("blas·argmax", Int, + Params( + (Int, "len"), (Ptr(Float64), "x"), + ), + Vars( + (Int, "i"), (Int, "ix"), (Float64, "max") + ) + ) + len, x, i, ix, max = F.variables("len", "x", "i", "ix", "max") + loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) + loop.execute( + If(GT(Index(x, i), max), + Block( + Set(ix, i), + Set(max, Index(x, i)), + ) + ) + ) + + F.execute(loop) + + F.emit() print(buffer) -- cgit v1.2.1