about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author Nicholas Noll <nbnoll@eml.cc> 2020-05-11 18:50:05 -0700
committer Nicholas Noll <nbnoll@eml.cc> 2020-05-11 18:50:05 -0700
commit b45f53b681d7bef4f1e96ee27f80c40dbf67573d (patch)
tree b440c289f5c1c318b56705401db213256c85916d
parent 2f4de9abb28ecb5af7390e7011a40f493c4b4dae (diff)
fix: made pattern of loads as generalized derefs more obvious
-rw-r--r-- lib/c.py 170
1 file changed, 118 insertions, 52 deletions
diff --git a/lib/c.py b/lib/c.py
index 070ade6..f85e7a7 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -148,7 +148,7 @@ Float32x8 = Base("__mm256")
Float64x2 = Base("__m128d")
Float64x4 = Base("__mm256d")
-def VectorType(kind):
+def IsVectorType(kind):
if kind is Float32x4 or \
kind is Float32x8 or \
kind is Float64x2 or \
@@ -158,6 +158,10 @@ def VectorType(kind):
kind is Int64x2 or \
kind is Int64x4:
return True
+
+ if type(kind) == Ptr:
+ return IsVectorType(kind.to)
+
return False
# TODO: Make this ARCH dependent
@@ -196,6 +200,46 @@ SIMDSupport = {
SIMD.AVX5: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
}
+# TODO: Think of a better way to handle this
+def emit·load128(l, r):
+ if l is not None:
+ l.emit()
+ emit(" = ")
+ emit("_mm_loadu_sd(")
+ r.emit()
+ emit(")")
+
+def emit·broadcast256(l, r):
+ l.emit()
+ emit(" = _mm256_broadcastsd_pd(")
+ r.emit()
+ emit(")")
+
+def emit·load256(l, r):
+ if l is not None:
+ l.emit()
+ emit(" = ")
+ emit("_mm256_loadu_pd(")
+ r.emit()
+ emit(")")
+
+def emit·store256(l, r):
+ emit("_mm256_storeu_pd(")
+ l.emit()
+ emit(", ")
+ r.emit()
+ emit(")")
+
+def emit·copy256(l, r):
+ emit("_mm256_storeu_pd(")
+ l.emit()
+ emit(", ")
+ emit("_mm256_loadu_pd(")
+ r.emit()
+ emit("))")
+
+
+
# TODO: Typedefs...
# ------------------------------------------
@@ -239,9 +283,18 @@ class UnaryOp(Expr):
pass
class Deref(UnaryOp):
+ method = {
+ Ptr(Float64x4) : lambda x: emit·load256(None, x),
+ }
+
def emit(self):
- emit("*")
- self.x.emit()
+ kind = GetType(self.x)
+ if kind in self.method:
+ print("emitting")
+ self.method[kind](self.x)
+ else:
+ emit("*")
+ self.x.emit()
class Negate(UnaryOp):
def emit(self):
@@ -375,40 +428,19 @@ class Assign(Expr):
self.lhs = lhs
self.rhs = rhs
-def emit·load256(l, r):
- emit("_mm256_loadu_pd(")
- l.emit()
- emit(", ")
- r.emit()
- emit(")")
-
-def emit·store256(l, r):
- emit("_mm256_storeu_pd(")
- l.emit()
- emit(", ")
- r.emit()
- emit(")")
-
-def emit·copy256(l, r):
- emit("_mm256_storeu_pd(")
- l.emit()
- emit(", ")
- emit("_mm256_loadu_pd(")
- r.emit()
- emit("))")
-
-SetTypeDispatch = {
- (Float64x4, Ptr(Float64x4)) : emit·load256,
- (Ptr(Float64x4), Float64x4) : emit·store256,
- (Ptr(Float64x4), Ptr(Float64x4)) : emit·copy256,
-}
-
class Set(Assign):
+ method = {
+ (Float64x2, Ptr(Float64)) : emit·load128,
+ (Float64x4, Float64x2) : emit·broadcast256,
+ (Float64x4, Ptr(Float64x4)) : emit·load256,
+ (Ptr(Float64x4), Float64x4) : emit·store256,
+ (Ptr(Float64x4), Ptr(Float64x4)) : emit·copy256,
+ }
def emit(self):
lhs = GetType(self.lhs)
rhs = GetType(self.rhs)
- if (lhs, rhs) in SetTypeDispatch:
- SetTypeDispatch[(lhs, rhs)](self.lhs, self.rhs)
+ if (lhs, rhs) in self.method:
+ self.method[(lhs, rhs)](self.lhs, self.rhs)
else:
self.lhs.emit()
emit(f" = ")
@@ -416,9 +448,14 @@ class Set(Assign):
class AddSet(Assign):
def emit(self):
- self.lhs.emit()
- emit(f" += ")
- self.rhs.emit()
+ lhs = GetType(self.lhs)
+ rhs = GetType(self.rhs)
+ if (lhs, rhs) in Set.method:
+ Set(self.lhs, Add(self.lhs, self.rhs)).emit()
+ else:
+ self.lhs.emit()
+ emit(f" += ")
+ self.rhs.emit()
class SubSet(Assign):
def emit(self):
@@ -760,6 +797,14 @@ def _(x: Assign, i: int):
return type(x)(Step(x.lhs, i), Step(x.rhs, i))
@Step.register
+def _(x: BinaryOp, i: int):
+ return type(x)(Step(x.l, i), Step(x.r, i))
+
+@Step.register
+def _(x: UnaryOp, i: int):
+ return type(x)(Step(x.x, i))
+
+@Step.register
def _(x: Ident, i: int):
return x
@@ -834,6 +879,8 @@ def _(x: StmtExpr, func):
@singledispatch
def GetType(x: Expr) -> Type:
+ if x is None:
+ return
raise TypeError(f"{type(x)} not supported by GetType operation")
@GetType.register
@@ -955,7 +1002,7 @@ def _(x: Comma, func):
@Transform.register
def _(x: BinaryOp, func):
- return func(BinaryOp(Transform(x.l, func), Transform(x.r, func)))
+ return func(type(x)(Transform(x.l, func), Transform(x.r, func)))
# ------------------------------------------
# Filter function
@@ -1171,7 +1218,13 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
# Create symbol table
if i == 0:
for s in scalars[0]:
- s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
+ if type(s.var) == Param:
+ intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128"))
+ s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
+ vfunc.execute(Set(intermediate, Ref(s)))
+ vfunc.execute(Set(s_symtab[s.name], intermediate))
+ else:
+ s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
for v in vectors[0].keys():
v_symtab[v.name] = Var(Ptr(Float64x4), f"{v.name}256")
@@ -1183,16 +1236,23 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
# loop.execute(load)
# IMPORTANT: We do a post-order traversal.
- # The identifier will be substituted first!
+ # We transforms leaves (identifiers) first and then move back up the root
def translate(x):
- if type(x) == Ident:
+ if isinstance(x, Ident):
if x.name in s_symtab:
return s_symtab[x.name]
elif x.name in v_symtab:
x.var = v_symtab[x.name]
- if type(x) == Index:
+ if isinstance(x, Index):
if type(x.x) == Ident and x.x.name in v_symtab:
return Add(x.x, x.i)
+ if isinstance(x, BinaryOp):
+ l, r = GetType(x.l), GetType(x.r)
+ if IsVectorType(l) or IsVectorType(r):
+ if type(l) == Ptr:
+ x.l = Deref(x.l)
+ if type(r) == Ptr:
+ x.r = Deref(x.r)
# if type(x) == Index and x.x in set(symtab.values()):
# return x.x
@@ -1216,23 +1276,29 @@ def Strided(func: Func) -> Func:
if __name__ == "__main__":
emitheader()
- Rot = Func("blas·swap", Void,
+ F = Func("blas·axpy", Void,
Params(
- (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y")
+ (Int, "len"), (Float64, "a"), (Ptr(Float64), "x"), (Ptr(Float64), "y")
),
Vars(
- (Int, "i"), (Float64, "tmp")
+ (Int, "i"), #(Float64, "tmp")
)
)
# could also declare like
- # Rot.declare( ... )
+ # F.declare( ... )
# TODO: Increase ergonomics here...
- len, x, y, i, tmp = Rot.variables("len", "x", "y", "i", "tmp")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i), Inc(x), Inc(y)],
- body = Swap(Index(x, I(0)), Index(y, I(0)), tmp)
- )
+ len, x, y, i, a = F.variables("len", "x", "y", "i", "a")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i), Inc(x), Inc(y)])
+ # body = Swap(Index(x, I(0)), Index(y, I(0)), tmp)
+ loop.execute(
+ Set(Index(y, I(0)),
+ Add(Index(y, I(0)),
+ Mul(a, Index(x, I(0)))
+ )
+ )
+ )
rem, kernel, call = Unroll(loop, 8, "swap")
kernel.emit()
@@ -1242,8 +1308,8 @@ if __name__ == "__main__":
avx256kernel.emit()
emitln(2)
- Rot.execute(Set(i, call), rem)
+ F.execute(Set(i, call), rem)
- Rot.emit()
+ F.emit()
print(buffer)