From c9d4b2d7dd1d9a46571e5d2b2cf6ce10a9d9ebea Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Wed, 13 May 2020 08:29:16 -0700 Subject: unrolling blas level 1 fully works --- lib/c.py | 98 +++++++++++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 79 insertions(+), 19 deletions(-) (limited to 'lib') diff --git a/lib/c.py b/lib/c.py index 762bd0b..bbc9ce3 100644 --- a/lib/c.py +++ b/lib/c.py @@ -1213,7 +1213,7 @@ def AddrAccessed(stmt: Stmt): # ------------------------------------------ # Large scale functions -def Unroll(name: str, loop: For, times: int, ret: Ident = None) -> (Func, List[Stmt]): +def Unroll(name: str, loop: For, times: int, ret: Ident = None, accumulator = None) -> (Func, List[Stmt]): # TODO: More sophisticated computation for length of loop if not isinstance(loop.cond, LE) and not isinstance(loop.cond, LT): raise TypeError(f"{type(loop.cond)} not supported in loop unrolling") @@ -1277,20 +1277,15 @@ def Unroll(name: str, loop: For, times: int, ret: Ident = None) -> (Func, List[S ) if ret is not None: - # NOTE: This is probably not general... - def index(node): - if node == ret: - return Index(node, I(0)) - if node == itor: - return Index(ret, itor) - return node - k = kernel.variables(ret.name) if k.var.type != ret.var.type: - resolve = For(Set(itor, I(1)), LT(itor, I(times)), Inc(itor), - body = Make(expandedloop, lambda node: Transform(node, index)) - ) - kernel.execute(resolve) + if accumulator is None: + raise ValueError("If loop returns a value, an accumulator must be given") + else: + accumulator = For(Set(itor, I(1)), LT(itor, I(times)), Inc(itor), + body = accumulator(I(0), itor, k) + ) + kernel.execute(accumulator) kernel.execute(Return(Index(ret, I(0)))) @@ -1323,8 +1318,10 @@ def Vectorize(func: Func, isa: SIMD) -> Func: body = stmt.body if type(body) != Block: - print("could not vectorize loop, skipping") - loop.body = body + # TODO: Think through this carefully... + # This is coded for the accumulation step at the end of a kernel + loop.cond.r = I(Eval(Div(loop.cond.r, 4))) + loop.body = body else: instr = body.stmts # TODO: Remove hardcoded 4 -> should be function of types! @@ -1494,26 +1491,89 @@ def argmax(): If(GT(Index(x, i), max), Block( Set(ix, i), - Set(max, Index(x, i)), + Set(max, Index(x, ix)), ) ) ) + kernel, calls = Unroll("argmax", loop, 8, ix, + lambda ires, icur, node: + If(GT(Index(x, Index(node, icur)), Index(x, Index(node, ires))), + Block(Set(Index(node, ires), Index(node, icur))) + ) + ) + kernel.emit() + + # avx256kernel = Vectorize(kernel, SIMD.AVX2) + # avx256kernel.emit() + + F.execute(*calls, Return(ix)) + + F.execute(loop) + F.emit() + +def dot(): + F = Func("blas·dot", Float64, + Params( + (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y") + ), + Vars( + (Int, "i"), (Float64, "sum"), + ) + ) + len, x, i, y, sum = F.variables("len", "x", "i", "y", "sum") + + loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) + loop.execute( + AddSet(sum, Mul(Index(x, i), Index(y, i))) + ) - kernel, calls = Unroll("argmax", loop, 8, ix) + kernel, calls = Unroll("dot", loop, 16, sum, + lambda ires, icur, node: + StmtExpr(AddSet(Index(node, ires), Index(node, icur))) + ) kernel.emit() avx256kernel = Vectorize(kernel, SIMD.AVX2) avx256kernel.emit() - F.execute(*calls) + F.execute(*calls, Return(sum)) F.emit() +def gemv(): + F = Func("blas·gemv", Void, + Params( + (Int, "nrow"), (Int, "ncol"), (Float64, "a"), (Ptr(Float64), "m"), (Int, "incm"), + (Ptr(Float64), "x"), (Float64, "b"), (Ptr(Float64), "y") + ), + Vars( + (Int, "r"), (Int, "c"), (Ptr(Float64), "row"), (Float64, "res") + ) + ) + r, c, row, res = F.variables("r", "c", "row", "res") + nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y") + loop = For(Set(r, I(0)), LT(r, nrow), Inc(r), + Block( + Set(row, Add(m, incm)), + Set(res, I(0)), + For(Set(c, I(0)), LT(c, ncol), Inc(c), + AddSet(res, Mul(Index(row, c), Index(x, c))) + ), + Set(Index(y, r), Add(Mul(a, res), Mul(b, Index(y, r)))) + ) + ) + + F.execute(loop) + F.emit() + + if __name__ == "__main__": emitheader() + gemv() + # dot() + # argmax() # copy() # axpby() - argmax() print(buffer) -- cgit v1.2.1