From b45f53b681d7bef4f1e96ee27f80c40dbf67573d Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Mon, 11 May 2020 18:50:05 -0700 Subject: fix: made pattern of loads as generalized derefs more obvious --- lib/c.py | 170 ++++++++++++++++++++++++++++++++++++++++++++------------------- 1 file changed, 118 insertions(+), 52 deletions(-) (limited to 'lib/c.py') diff --git a/lib/c.py b/lib/c.py index 070ade6..f85e7a7 100644 --- a/lib/c.py +++ b/lib/c.py @@ -148,7 +148,7 @@ Float32x8 = Base("__mm256") Float64x2 = Base("__m128d") Float64x4 = Base("__mm256d") -def VectorType(kind): +def IsVectorType(kind): if kind is Float32x4 or \ kind is Float32x8 or \ kind is Float64x2 or \ @@ -158,6 +158,10 @@ def VectorType(kind): kind is Int64x2 or \ kind is Int64x4: return True + + if type(kind) == Ptr: + return IsVectorType(kind.to) + return False # TODO: Make this ARCH dependent @@ -196,6 +200,46 @@ SIMDSupport = { SIMD.AVX5: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]), } +# TODO: Think of a better way to handle this +def emit·load128(l, r): + if l is not None: + l.emit() + emit(" = ") + emit("_mm_loadu_sd(") + r.emit() + emit(")") + +def emit·broadcast256(l, r): + l.emit() + emit(" = _mm256_broadcastsd_pd(") + r.emit() + emit(")") + +def emit·load256(l, r): + if l is not None: + l.emit() + emit(" = ") + emit("_mm256_loadu_pd(") + r.emit() + emit(")") + +def emit·store256(l, r): + emit("_mm256_storeu_pd(") + l.emit() + emit(", ") + r.emit() + emit(")") + +def emit·copy256(l, r): + emit("_mm256_storeu_pd(") + l.emit() + emit(", ") + emit("_mm256_loadu_pd(") + r.emit() + emit("))") + + + # TODO: Typedefs... 
# ------------------------------------------ @@ -239,9 +283,18 @@ class UnaryOp(Expr): pass class Deref(UnaryOp): + method = { + Ptr(Float64x4) : lambda x: emit·load256(None, x), + } + def emit(self): - emit("*") - self.x.emit() + kind = GetType(self.x) + if kind in self.method: + print("emitting") + self.method[kind](self.x) + else: + emit("*") + self.x.emit() class Negate(UnaryOp): def emit(self): @@ -375,40 +428,19 @@ class Assign(Expr): self.lhs = lhs self.rhs = rhs -def emit·load256(l, r): - emit("_mm256_loadu_pd(") - l.emit() - emit(", ") - r.emit() - emit(")") - -def emit·store256(l, r): - emit("_mm256_storeu_pd(") - l.emit() - emit(", ") - r.emit() - emit(")") - -def emit·copy256(l, r): - emit("_mm256_storeu_pd(") - l.emit() - emit(", ") - emit("_mm256_loadu_pd(") - r.emit() - emit("))") - -SetTypeDispatch = { - (Float64x4, Ptr(Float64x4)) : emit·load256, - (Ptr(Float64x4), Float64x4) : emit·store256, - (Ptr(Float64x4), Ptr(Float64x4)) : emit·copy256, -} - class Set(Assign): + method = { + (Float64x2, Ptr(Float64)) : emit·load128, + (Float64x4, Float64x2) : emit·broadcast256, + (Float64x4, Ptr(Float64x4)) : emit·load256, + (Ptr(Float64x4), Float64x4) : emit·store256, + (Ptr(Float64x4), Ptr(Float64x4)) : emit·copy256, + } def emit(self): lhs = GetType(self.lhs) rhs = GetType(self.rhs) - if (lhs, rhs) in SetTypeDispatch: - SetTypeDispatch[(lhs, rhs)](self.lhs, self.rhs) + if (lhs, rhs) in self.method: + self.method[(lhs, rhs)](self.lhs, self.rhs) else: self.lhs.emit() emit(f" = ") @@ -416,9 +448,14 @@ class Set(Assign): class AddSet(Assign): def emit(self): - self.lhs.emit() - emit(f" += ") - self.rhs.emit() + lhs = GetType(self.lhs) + rhs = GetType(self.rhs) + if (lhs, rhs) in Set.method: + Set(self.lhs, Add(self.lhs, self.rhs)).emit() + else: + self.lhs.emit() + emit(f" += ") + self.rhs.emit() class SubSet(Assign): def emit(self): @@ -759,6 +796,14 @@ def _(x: Comma, i: int): def _(x: Assign, i: int): return type(x)(Step(x.lhs, i), Step(x.rhs, i)) 
+@Step.register +def _(x: BinaryOp, i: int): + return type(x)(Step(x.l, i), Step(x.r, i)) + +@Step.register +def _(x: UnaryOp, i: int): + return type(x)(Step(x.x, i)) + @Step.register def _(x: Ident, i: int): return x @@ -834,6 +879,8 @@ def _(x: StmtExpr, func): @singledispatch def GetType(x: Expr) -> Type: + if x is None: + return raise TypeError(f"{type(x)} not supported by GetType operation") @GetType.register @@ -955,7 +1002,7 @@ def _(x: Comma, func): @Transform.register def _(x: BinaryOp, func): - return func(BinaryOp(Transform(x.l, func), Transform(x.r, func))) + return func(type(x)(Transform(x.l, func), Transform(x.r, func))) # ------------------------------------------ # Filter function @@ -1171,7 +1218,13 @@ def Vectorize(func: Func, isa: SIMD) -> Func: # Create symbol table if i == 0: for s in scalars[0]: - s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256")) + if type(s.var) == Param: + intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128")) + s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256")) + vfunc.execute(Set(intermediate, Ref(s))) + vfunc.execute(Set(s_symtab[s.name], intermediate)) + else: + s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256")) for v in vectors[0].keys(): v_symtab[v.name] = Var(Ptr(Float64x4), f"{v.name}256") @@ -1183,16 +1236,23 @@ def Vectorize(func: Func, isa: SIMD) -> Func: # loop.execute(load) # IMPORTANT: We do a post-order traversal. - # The identifier will be substituted first! 
+ # We transform leaves (identifiers) first and then move back up to the root def translate(x): - if type(x) == Ident: + if isinstance(x, Ident): if x.name in s_symtab: return s_symtab[x.name] elif x.name in v_symtab: x.var = v_symtab[x.name] - if type(x) == Index: + if isinstance(x, Index): if type(x.x) == Ident and x.x.name in v_symtab: return Add(x.x, x.i) + if isinstance(x, BinaryOp): + l, r = GetType(x.l), GetType(x.r) + if IsVectorType(l) or IsVectorType(r): + if type(l) == Ptr: + x.l = Deref(x.l) + if type(r) == Ptr: + x.r = Deref(x.r) # if type(x) == Index and x.x in set(symtab.values()): # return x.x @@ -1216,23 +1276,29 @@ def Strided(func: Func) -> Func: if __name__ == "__main__": emitheader() - Rot = Func("blas·swap", Void, + F = Func("blas·axpy", Void, Params( - (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y") + (Int, "len"), (Float64, "a"), (Ptr(Float64), "x"), (Ptr(Float64), "y") ), Vars( - (Int, "i"), (Float64, "tmp") + (Int, "i"), #(Float64, "tmp") ) ) # could also declare like - # Rot.declare( ... ) + # F.declare( ... ) # TODO: Increase ergonomics here... - len, x, y, i, tmp = Rot.variables("len", "x", "y", "i", "tmp") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i), Inc(x), Inc(y)], - body = Swap(Index(x, I(0)), Index(y, I(0)), tmp) - ) + len, x, y, i, a = F.variables("len", "x", "y", "i", "a") + + loop = For(Set(i, I(0)), LT(i, len), [Inc(i), Inc(x), Inc(y)]) + # body = Swap(Index(x, I(0)), Index(y, I(0)), tmp) + loop.execute( + Set(Index(y, I(0)), + Add(Index(y, I(0)), + Mul(a, Index(x, I(0))) + ) + ) + ) rem, kernel, call = Unroll(loop, 8, "swap") kernel.emit() @@ -1242,8 +1308,8 @@ if __name__ == "__main__": avx256kernel.emit() emitln(2) - Rot.execute(Set(i, call), rem) + F.execute(Set(i, call), rem) - Rot.emit() + F.emit() print(buffer) -- cgit v1.2.1