aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-12 18:38:59 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-12 18:38:59 -0700
commitd3241acc69327081c2f9c2b1d9ed4ae96d8f1287 (patch)
tree3c11708cc61d8d92d925ebb22a51b99a5e042d90
parent1ec3d68c86dafd02520edd62d954c833e28515e3 (diff)
feat: allow for cleanup of end of vectorization functions
-rw-r--r--lib/c.py94
1 files changed, 61 insertions, 33 deletions
diff --git a/lib/c.py b/lib/c.py
index 08203cb..762bd0b 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -689,6 +689,7 @@ class Func(Decl):
self.stmts[-1].emit()
exits_scope()
+ emitln(2)
def declare(self, var: Var, *vars: List[Var | List[Var]]) -> Expr | List[Expr]:
if var.name in [v.name for v in self.vars]:
@@ -978,7 +979,8 @@ def _(x: Index):
elif isinstance(base, Array):
return base.base
else:
- raise TypeError(f"attempting to index type {base}")
+ pass
+ # raise TypeError(f"attempting to index type {base} of node {x.x}")
# TODO: type checking for both
@@ -1206,13 +1208,12 @@ def AddrAccessed(stmt: Stmt):
if vars[0] in scalars:
scalars.remove(vars[0])
- print([(key.name, val) for key, val in vectors.items()])
return set(scalars), vectors
# ------------------------------------------
# Large scale functions
-def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Func, Call):
+def Unroll(name: str, loop: For, times: int, ret: Ident = None) -> (Func, List[Stmt]):
# TODO: More sophisticated computation for length of loop
if not isinstance(loop.cond, LE) and not isinstance(loop.cond, LT):
raise TypeError(f"{type(loop.cond)} not supported in loop unrolling")
@@ -1227,16 +1228,22 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
else:
return Var.copy(v)
- param = [Param(Ptr(i.type), f"{i.name}p")] + [Param.copy(v) for v in vars if type(v) == Param]
+ n = loop.cond.r.var
+ param = [Param(Ptr(i.type), f"{i.name}p")] + [Param.copy(v) for v in vars if (type(v) == Param and v != n)]
stack = {v.name: asvector(v) for v in vars if type(v) == Var}
# TODO: More sophisticated type checking
- n = loop.cond.r.var
if (type(n) != Param):
raise TypeError(f"{type(n)} not implemented yet")
- kernel = Func(f"{name}_kernel{times}", Void, param, list(stack.values()))
- body = loop.body
+ if ret is None:
+ kernel = Func(f"{name}_kernel{times}", Void, param, list(stack.values()))
+ else:
+ kernel = Func(f"{name}_kernel{times}", ret.var.type, param, list(stack.values()))
+
+ len = kernel.declare(n)
+
+ body = loop.body
itor, itorp = kernel.variables("i", "ip")
def mkarray(x: Expr, times: int):
@@ -1256,24 +1263,42 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
return x
+ expandedloop = Make(loop.body, lambda node: Transform(node, lambda x: mkarray(x, times)))
kernel.execute(
- Set(n, EvenTo(n, times)),
- For(Set(itor, I(0)), LT(itor, n), Repeat(loop.step, times),
+ Set(len, EvenTo(Deref(itorp), times)),
+ For(Set(itor, I(0)), LT(itor, len), Repeat(loop.step, times),
body = Block(*
- (Make(
- Make(loop.body,
- lambda node:
- Transform(node, lambda x: mkarray(x, times))),
- lambda node:
- Transform(node, lambda x: step(x, i))) for i in range(times)
+ (Make(expandedloop, lambda node:
+ Transform(node, lambda x: step(x, i))) for i in range(times)
)
)
),
Set(Deref(itorp), itor)
)
+ if ret is not None:
+ # NOTE: This is probably not general...
+ def index(node):
+ if node == ret:
+ return Index(node, I(0))
+ if node == itor:
+ return Index(ret, itor)
+ return node
+
+ k = kernel.variables(ret.name)
+ if k.var.type != ret.var.type:
+ resolve = For(Set(itor, I(1)), LT(itor, I(times)), Inc(itor),
+ body = Make(expandedloop, lambda node: Transform(node, index))
+ )
+ kernel.execute(resolve)
+
+ kernel.execute(Return(Index(ret, I(0))))
+
loop.init = None
- return loop, kernel, Call(kernel, param)
+ if ret is None:
+ return kernel, [Set(itor, len), Call(kernel, [Ref(itor)] + param[1:]), loop]
+ else:
+ return kernel, [Set(itor, len), Set(ret, Call(kernel, [Ref(itor)] + param[1:])), loop]
# Replaces all vectorizable loops inside Func with vectorized variants
# Returns the new vectorized function (tagged by the SIMD chosen)
@@ -1331,6 +1356,12 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
vecs = vectors[0]
if i == 0:
syms = SymTab()
+ for s in scalars[0]:
+ intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128"))
+ syms.stack[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
+ vfunc.execute(Set(intermediate, Ref(s)))
+ vfunc.execute(Set(syms.stack[s.name], intermediate))
+
for v in vecs.keys():
# All params are treated AS addresses to load into vectorized registers
if type(v.var) == Param:
@@ -1344,6 +1375,7 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
else:
syms.stack[v.name] = vfunc.declare(Var(Float64x4, f"{v.name}256"))
+
# IMPORTANT: We do a post-order traversal.
# We transform leaves (identifiers) first and then move back up to the root
def translate(x):
@@ -1356,7 +1388,7 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
if type(x.x) == Ident:
if x.x.name in syms.addrs:
return Add(x.x, x.i)
- elif f"{x.x.name[:-3]}" in syms.stack:
+ elif f"{x.x.name[:-3]}" in syms.stack: #NOTE: This is hacky. Think of something better
return Index(x.x, I(Eval(Div(x.i, 4))))
if isinstance(x, BinaryOp):
l, r = GetType(x.l), GetType(x.r)
@@ -1402,17 +1434,14 @@ def copy():
Swap(Index(x, i), Index(y, i), tmp)
)
- rem, kernel, call = Unroll(loop, 8, "copy")
+ kernel, calls = Unroll("copy", loop, 8)
kernel.emit()
- emitln(2)
avx256kernel = Vectorize(kernel, SIMD.AVX2)
avx256kernel.emit()
- emitln(2)
- F.execute(Set(i, call), rem)
+ F.execute(*calls)
F.emit()
- print(buffer)
def axpby():
F = Func("blas·axpby", Void,
@@ -1440,17 +1469,14 @@ def axpby():
)
)
- rem, kernel, call = Unroll(loop, 8, "axpby")
+ kernel, calls = Unroll("axpby", loop, 8)
kernel.emit()
- emitln(2)
avx256kernel = Vectorize(kernel, SIMD.AVX2)
avx256kernel.emit()
- emitln(2)
- F.execute(Set(i, call), rem)
+ F.execute(*calls)
F.emit()
- print(buffer)
def argmax():
F = Func("blas·argmax", Int,
@@ -1473,19 +1499,21 @@ def argmax():
)
)
- loop, kernel, call = Unroll(loop, 8, "argmax")
+ kernel, calls = Unroll("argmax", loop, 8, ix)
kernel.emit()
- emitln(2)
- F.execute(loop)
+ avx256kernel = Vectorize(kernel, SIMD.AVX2)
+ avx256kernel.emit()
+
+ F.execute(*calls)
F.emit()
- print(buffer)
if __name__ == "__main__":
emitheader()
- copy()
+ # copy()
# axpby()
- # argmax()
+ argmax()
+ print(buffer)