From d3241acc69327081c2f9c2b1d9ed4ae96d8f1287 Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Tue, 12 May 2020 18:38:59 -0700 Subject: feat: allow for cleanup of end of vectorization functions --- lib/c.py | 94 +++++++++++++++++++++++++++++++++++++++++----------------------- 1 file changed, 61 insertions(+), 33 deletions(-) (limited to 'lib/c.py') diff --git a/lib/c.py b/lib/c.py index 08203cb..762bd0b 100644 --- a/lib/c.py +++ b/lib/c.py @@ -689,6 +689,7 @@ class Func(Decl): self.stmts[-1].emit() exits_scope() + emitln(2) def declare(self, var: Var, *vars: List[Var | List[Var]]) -> Expr | List[Expr]: if var.name in [v.name for v in self.vars]: @@ -978,7 +979,8 @@ def _(x: Index): elif isinstance(base, Array): return base.base else: - raise TypeError(f"attempting to index type {base}") + pass + # raise TypeError(f"attempting to index type {base} of node {x.x}") # TODO: type checking for both @@ -1206,13 +1208,12 @@ def AddrAccessed(stmt: Stmt): if vars[0] in scalars: scalars.remove(vars[0]) - print([(key.name, val) for key, val in vectors.items()]) return set(scalars), vectors # ------------------------------------------ # Large scale functions -def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Func, Call): +def Unroll(name: str, loop: For, times: int, ret: Ident = None) -> (Func, List[Stmt]): # TODO: More sophisticated computation for length of loop if not isinstance(loop.cond, LE) and not isinstance(loop.cond, LT): raise TypeError(f"{type(loop.cond)} not supported in loop unrolling") @@ -1227,16 +1228,22 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun else: return Var.copy(v) - param = [Param(Ptr(i.type), f"{i.name}p")] + [Param.copy(v) for v in vars if type(v) == Param] + n = loop.cond.r.var + param = [Param(Ptr(i.type), f"{i.name}p")] + [Param.copy(v) for v in vars if (type(v) == Param and v != n)] stack = {v.name: asvector(v) for v in vars if type(v) == Var} # TODO: More sophisticated type checking - n = loop.cond.r.var if (type(n) != Param): raise TypeError(f"{type(n)} not implemented yet") - kernel = Func(f"{name}_kernel{times}", Void, param, list(stack.values())) - body = loop.body + if ret is None: + kernel = Func(f"{name}_kernel{times}", Void, param, list(stack.values())) + else: + kernel = Func(f"{name}_kernel{times}", ret.var.type, param, list(stack.values())) + + len = kernel.declare(n) + + body = loop.body itor, itorp = kernel.variables("i", "ip") def mkarray(x: Expr, times: int): @@ -1256,24 +1263,42 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun return x + expandedloop = Make(loop.body, lambda node: Transform(node, lambda x: mkarray(x, times))) kernel.execute( - Set(n, EvenTo(n, times)), - For(Set(itor, I(0)), LT(itor, n), Repeat(loop.step, times), + Set(len, EvenTo(Deref(itorp), times)), + For(Set(itor, I(0)), LT(itor, len), Repeat(loop.step, times), body = Block(* - (Make( - Make(loop.body, - lambda node: - Transform(node, lambda x: mkarray(x, times))), - lambda node: - Transform(node, lambda x: step(x, i))) for i in range(times) + (Make(expandedloop, lambda node: + Transform(node, lambda x: step(x, i))) for i in range(times) ) ) ), Set(Deref(itorp), itor) ) + if ret is not None: + # NOTE: This is probably not general... + def index(node): + if node == ret: + return Index(node, I(0)) + if node == itor: + return Index(ret, itor) + return node + + k = kernel.variables(ret.name) + if k.var.type != ret.var.type: + resolve = For(Set(itor, I(1)), LT(itor, I(times)), Inc(itor), + body = Make(expandedloop, lambda node: Transform(node, index)) + ) + kernel.execute(resolve) + + kernel.execute(Return(Index(ret, I(0)))) + loop.init = None - return loop, kernel, Call(kernel, param) + if ret is None: + return kernel, [Set(itor, len), Call(kernel, [Ref(itor)] + param[1:]), loop] + else: + return kernel, [Set(itor, len), Set(ret, Call(kernel, [Ref(itor)] + param[1:])), loop] # Replaces all vectorizable loops inside Func with vectorized variants # Returns the new vectorized function (tagged by the SIMD chosen) @@ -1331,6 +1356,12 @@ def Vectorize(func: Func, isa: SIMD) -> Func: vecs = vectors[0] if i == 0: syms = SymTab() + for s in scalars[0]: + intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128")) + syms.stack[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256")) + vfunc.execute(Set(intermediate, Ref(s))) + vfunc.execute(Set(syms.stack[s.name], intermediate)) + for v in vecs.keys(): # All params are treated AS addresses to load into vectorized registers if type(v.var) == Param: @@ -1344,6 +1375,7 @@ def Vectorize(func: Func, isa: SIMD) -> Func: else: syms.stack[v.name] = vfunc.declare(Var(Float64x4, f"{v.name}256")) + # IMPORTANT: We do a post-order traversal. # We transforms leaves (identifiers) first and then move back up the root def translate(x): @@ -1356,7 +1388,7 @@ def Vectorize(func: Func, isa: SIMD) -> Func: if type(x.x) == Ident: if x.x.name in syms.addrs: return Add(x.x, x.i) - elif f"{x.x.name[:-3]}" in syms.stack: + elif f"{x.x.name[:-3]}" in syms.stack: #NOTE: This is hacky. Think of something better return Index(x.x, I(Eval(Div(x.i, 4)))) if isinstance(x, BinaryOp): l, r = GetType(x.l), GetType(x.r) @@ -1402,17 +1434,14 @@ def copy(): Swap(Index(x, i), Index(y, i), tmp) ) - rem, kernel, call = Unroll(loop, 8, "copy") + kernel, calls = Unroll("copy", loop, 8) kernel.emit() - emitln(2) avx256kernel = Vectorize(kernel, SIMD.AVX2) avx256kernel.emit() - emitln(2) - F.execute(Set(i, call), rem) + F.execute(*calls) F.emit() - print(buffer) def axpby(): F = Func("blas·axpby", Void, @@ -1440,17 +1469,14 @@ def axpby(): ) ) - rem, kernel, call = Unroll(loop, 8, "axpby") + kernel, calls = Unroll("axpby", loop, 8) kernel.emit() - emitln(2) avx256kernel = Vectorize(kernel, SIMD.AVX2) avx256kernel.emit() - emitln(2) - F.execute(Set(i, call), rem) + F.execute(*calls) F.emit() - print(buffer) def argmax(): F = Func("blas·argmax", Int, @@ -1473,19 +1499,21 @@ def argmax(): ) ) - loop, kernel, call = Unroll(loop, 8, "argmax") + kernel, calls = Unroll("argmax", loop, 8, ix) kernel.emit() - emitln(2) - F.execute(loop) + avx256kernel = Vectorize(kernel, SIMD.AVX2) + avx256kernel.emit() + + F.execute(*calls) F.emit() - print(buffer) if __name__ == "__main__": emitheader() - copy() + # copy() # axpby() - # argmax() + argmax() + print(buffer) -- cgit v1.2.1