From 9a7c2b4f12b976ee080412c3b685fb49e3d20d34 Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Tue, 12 May 2020 15:34:15 -0700 Subject: checkpoint: plan to simplify vectorization code --- lib/c.py | 139 ++++++++++++++++++++++++++++++++++++++++++++++++++------------- 1 file changed, 110 insertions(+), 29 deletions(-) (limited to 'lib') diff --git a/lib/c.py b/lib/c.py index d88f7ab..25c02eb 100644 --- a/lib/c.py +++ b/lib/c.py @@ -7,7 +7,7 @@ import sys import sympy from enum import Enum -from typing import List, Tuple, Dict +from typing import Callable, List, Tuple, Dict from functools import singledispatch numel = len @@ -249,7 +249,7 @@ class Expr(Emitter): pass # Literals -class Literal(): +class Literal(Expr): def emit(self): emit(f"{self}") @@ -738,10 +738,18 @@ class Var(Decl): def emitspec(self): self.type.emitspec(self.name) + @classmethod + def copy(cls, other: Var): + return cls(other.type, other.name, other.storage) + class Param(Var): def __init__(self, type: Type, name: str): return super(Param, self).__init__(type, name, Mem.Auto) + @classmethod + def copy(cls, other: Param): + return cls(other.type, other.name) + def Params(*ps: List[Tuple(Type, str)]) -> List[Param]: return [Param(p[0], p[1]) for p in ps] @@ -897,11 +905,14 @@ def _(blk: Block, func): @Make.register def _(loop: For, func): - return For(func(loop.init), func(loop.cond), func(loop.step), func(loop.body)) + return For(func(loop.init), func(loop.cond), func(loop.step), Make(loop.body, func)) @Make.register def _(jmp: If, func): - return If(func(jmp.cond), Make(jmp.then, func), Make(jmp.orelse, func) if jmp.orelse is not None else None) + if jmp.orelse is not None: + return If(func(jmp.cond), Make(jmp.then, func), Make(jmp.orelse, func)) + else: + return If(func(jmp.cond), Make(jmp.then, func)) @Make.register def _(ret: Return, func): @@ -1002,7 +1013,7 @@ def _(x: Call): # Returns a brand new tree structure @singledispatch -def Transform(x: Expr, func): +def Transform(x: Expr, func: Callable[Expr, Expr]): raise TypeError(f"{type(x)} not supported by Transform operation") @Transform.register @@ -1192,29 +1203,59 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun i = loop.init.lhs.var vars = VarsUsed(loop.body) - params = [v for v in vars if type(v) == Param] - stacks = [v for v in vars if type(v) == Var] + def asvector(v): + if v != i: + return Var(Array(v.type, times), v.name, v.storage) + else: + return Var.copy(v) + + param = [Param(Ptr(i.type), f"{i.name}p")] + [Param.copy(v) for v in vars if type(v) == Param] + stack = {v.name: asvector(v) for v in vars if type(v) == Var} # TODO: More sophisticated type checking n = loop.cond.r.var if (type(n) != Param): raise TypeError(f"{type(n)} not implemented yet") - params = [n] + params - kernel = Func(f"{name}_kernel{times}", Int, params, stacks) - i = kernel.variables("i") + kernel = Func(f"{name}_kernel{times}", Void, param, list(stack.values())) body = loop.body + itor, itorp = kernel.variables("i", "ip") + + def mkarray(x: Expr, times: int): + if isinstance(x, Ident): + if type(x.var) == Var: + x.var = stack[x.name] + return x + + def step(x: Expr, times: int): + if x == itor: + return Add(x, I(times)) + if isinstance(x, Ident): + if type(x.var) == Var: + return Index(x, I(times)) + if not isinstance(x, Expr): + raise ValueError(f"panic, hit type {type(x)}") + + return x kernel.execute( Set(n, EvenTo(n, times)), - For(Set(i, I(0)), LT(i, n), Repeat(loop.step, times), - body = Block(*(Make(loop.body, lambda x: Step(x, i)) for i in range(times))) + For(Set(itor, I(0)), LT(itor, n), Repeat(loop.step, times), + body = Block(* + (Make( + Make(loop.body, + lambda node: + Transform(node, lambda x: mkarray(x, times))), + lambda node: + Transform(node, lambda x: step(x, i))) for i in range(times) + ) + ) ), - Return(n) + Set(Deref(itorp), itor) ) loop.init = None - return loop, kernel, Call(kernel, params) + return loop, kernel, Call(kernel, param) # Replaces all vectorizable loops inside Func with vectorized variants # Returns the new vectorized function (tagged by the SIMD chosen) @@ -1225,7 +1266,6 @@ def Vectorize(func: Func, isa: SIMD) -> Func: vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params) for stmt in func.stmts: if IsLoop(stmt): - loop = For(stmt.init, stmt.cond, stmt.step) iterator = set([stmt.init.lhs]) @@ -1240,8 +1280,12 @@ def Vectorize(func: Func, isa: SIMD) -> Func: if numel(instr) % 4 != 0: raise ValueError("loop can not be vectorized, instructions can not be globbed equally") + class symtab(object): + scalar = {} + stack = {} + addrs = {} + # TODO: Allow for non-sequential accesses? - s_symtab, v_symtab = {}, {} for i in range(0, numel(instr), 4): scalars = [] vectors = [] @@ -1266,26 +1310,31 @@ def Vectorize(func: Func, isa: SIMD) -> Func: if i == 0: for s in scalars[0]: if type(s.var) == Param: - intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128")) - s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256")) + intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128")) + symtab.scalar[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256")) vfunc.execute(Set(intermediate, Ref(s))) - vfunc.execute(Set(s_symtab[s.name], intermediate)) + vfunc.execute(Set(symtab.scalar[s.name], intermediate)) else: - s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256")) + symtab.scalar[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256")) for v in vectors[0].keys(): - v_symtab[v.name] = Var(Ptr(Float64x4), f"{v.name}256") + if type(v.var) == Param: + symtab.addrs[v.name] = Var(Ptr(Float64x4), f"{v.name}256") + else: + symtab.stack[v.name] = vfunc.declare(Var(Float64x4, f"{v.name}256")) # IMPORTANT: We do a post-order traversal. # We transforms leaves (identifiers) first and then move back up the root def translate(x): if isinstance(x, Ident): - if x.name in s_symtab: - return s_symtab[x.name] - elif x.name in v_symtab: - x.var = v_symtab[x.name] + if x.name in symtab.scalar: + return symtab.scalar[x.name] + elif x.name in symtab.stack: + return symtab.stack[x.name] + elif x.name in symtab.addrs: + x.var = symtab.addrs[x.name] if isinstance(x, Index): - if type(x.x) == Ident and x.x.name in v_symtab: + if type(x.x) == Ident and (x.x.name in symtab.addrs or x.x.name in symtab.stack): return Add(x.x, x.i) if isinstance(x, BinaryOp): l, r = GetType(x.l), GetType(x.r) @@ -1312,6 +1361,37 @@ def Strided(func: Func) -> Func: # ------------------------------------------------------------------------ # Point of testing +def copy(): + F = Func("blas·copy", Void, + Params( + (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y") + ), + Vars( + (Int, "i"), (Float64, "tmp") + ) + ) + # could also declare like + # F.declare( ... ) + + len, x, y, i, tmp = F.variables("len", "x", "y", "i", "tmp") + + loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) + loop.execute( + Swap(Index(x, i), Index(y, i), tmp) + ) + + rem, kernel, call = Unroll(loop, 8, "copy") + kernel.emit() + emitln(2) + + avx256kernel = Vectorize(kernel, SIMD.AVX2) + avx256kernel.emit() + emitln(2) + + F.execute(Set(i, call), rem) + F.emit() + print(buffer) + def axpby(): F = Func("blas·axpby", Void, Params( @@ -1328,7 +1408,6 @@ def axpby(): len, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b") loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - # body = Swap(Index(x, I(0)), Index(y, I(0)), tmp) loop.execute( Set(Index(y, i), Mul(b, @@ -1339,7 +1418,7 @@ def axpby(): ) ) - rem, kernel, call = Unroll(loop, 16, "axpy") + rem, kernel, call = Unroll(loop, 8, "axpby") kernel.emit() emitln(2) @@ -1384,5 +1463,7 @@ def argmax(): if __name__ == "__main__": emitheader() - argmax() + copy() + # axpby() + # argmax() -- cgit v1.2.1