about | summary | refs | log | tree | commit | diff
path: root/lib/c.py
diff options
context:
space:
mode:
author: Nicholas Noll <nbnoll@eml.cc> 2020-05-13 08:29:16 -0700
committer: Nicholas Noll <nbnoll@eml.cc> 2020-05-13 08:29:16 -0700
commit: c9d4b2d7dd1d9a46571e5d2b2cf6ce10a9d9ebea (patch)
tree: ad3c1cf1d3295760e7c32d6cdd17846febf1dbea /lib/c.py
parent: d3241acc69327081c2f9c2b1d9ed4ae96d8f1287 (diff)
unrolling blas level 1 fully works
Diffstat (limited to 'lib/c.py')
-rw-r--r-- lib/c.py 98
1 files changed, 79 insertions, 19 deletions
diff --git a/lib/c.py b/lib/c.py
index 762bd0b..bbc9ce3 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -1213,7 +1213,7 @@ def AddrAccessed(stmt: Stmt):
# ------------------------------------------
# Large scale functions
-def Unroll(name: str, loop: For, times: int, ret: Ident = None) -> (Func, List[Stmt]):
+def Unroll(name: str, loop: For, times: int, ret: Ident = None, accumulator = None) -> (Func, List[Stmt]):
# TODO: More sophisticated computation for length of loop
if not isinstance(loop.cond, LE) and not isinstance(loop.cond, LT):
raise TypeError(f"{type(loop.cond)} not supported in loop unrolling")
@@ -1277,20 +1277,15 @@ def Unroll(name: str, loop: For, times: int, ret: Ident = None) -> (Func, List[S
)
if ret is not None:
- # NOTE: This is probably not general...
- def index(node):
- if node == ret:
- return Index(node, I(0))
- if node == itor:
- return Index(ret, itor)
- return node
-
k = kernel.variables(ret.name)
if k.var.type != ret.var.type:
- resolve = For(Set(itor, I(1)), LT(itor, I(times)), Inc(itor),
- body = Make(expandedloop, lambda node: Transform(node, index))
- )
- kernel.execute(resolve)
+ if accumulator is None:
+ raise ValueError("If loop returns a value, an accumulator must be given")
+ else:
+ accumulator = For(Set(itor, I(1)), LT(itor, I(times)), Inc(itor),
+ body = accumulator(I(0), itor, k)
+ )
+ kernel.execute(accumulator)
kernel.execute(Return(Index(ret, I(0))))
@@ -1323,8 +1318,10 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
body = stmt.body
if type(body) != Block:
- print("could not vectorize loop, skipping")
- loop.body = body
+ # TODO: Think through this carefully...
+ # This is coded for the accumulation step at the end of a kernel
+ loop.cond.r = I(Eval(Div(loop.cond.r, 4)))
+ loop.body = body
else:
instr = body.stmts
# TODO: Remove hardcoded 4 -> should be function of types!
@@ -1494,26 +1491,89 @@ def argmax():
If(GT(Index(x, i), max),
Block(
Set(ix, i),
- Set(max, Index(x, i)),
+ Set(max, Index(x, ix)),
)
)
)
+ kernel, calls = Unroll("argmax", loop, 8, ix,
+ lambda ires, icur, node:
+ If(GT(Index(x, Index(node, icur)), Index(x, Index(node, ires))),
+ Block(Set(Index(node, ires), Index(node, icur)))
+ )
+ )
+ kernel.emit()
+
+ # avx256kernel = Vectorize(kernel, SIMD.AVX2)
+ # avx256kernel.emit()
+
+ F.execute(*calls, Return(ix))
+
+ F.execute(loop)
+ F.emit()
+
+def dot():
+ F = Func("blas·dot", Float64,
+ Params(
+ (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y")
+ ),
+ Vars(
+ (Int, "i"), (Float64, "sum"),
+ )
+ )
+ len, x, i, y, sum = F.variables("len", "x", "i", "y", "sum")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ AddSet(sum, Mul(Index(x, i), Index(y, i)))
+ )
- kernel, calls = Unroll("argmax", loop, 8, ix)
+ kernel, calls = Unroll("dot", loop, 16, sum,
+ lambda ires, icur, node:
+ StmtExpr(AddSet(Index(node, ires), Index(node, icur)))
+ )
kernel.emit()
avx256kernel = Vectorize(kernel, SIMD.AVX2)
avx256kernel.emit()
- F.execute(*calls)
+ F.execute(*calls, Return(sum))
F.emit()
+def gemv():
+ F = Func("blas·gemv", Void,
+ Params(
+ (Int, "nrow"), (Int, "ncol"), (Float64, "a"), (Ptr(Float64), "m"), (Int, "incm"),
+ (Ptr(Float64), "x"), (Float64, "b"), (Ptr(Float64), "y")
+ ),
+ Vars(
+ (Int, "r"), (Int, "c"), (Ptr(Float64), "row"), (Float64, "res")
+ )
+ )
+ r, c, row, res = F.variables("r", "c", "row", "res")
+ nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")
+ loop = For(Set(r, I(0)), LT(r, nrow), Inc(r),
+ Block(
+ Set(row, Add(m, incm)),
+ Set(res, I(0)),
+ For(Set(c, I(0)), LT(c, ncol), Inc(c),
+ AddSet(res, Mul(Index(row, c), Index(x, c)))
+ ),
+ Set(Index(y, r), Add(Mul(a, res), Mul(b, Index(y, r))))
+ )
+ )
+
+ F.execute(loop)
+ F.emit()
+
+
if __name__ == "__main__":
emitheader()
+ gemv()
+ # dot()
+ # argmax()
# copy()
# axpby()
- argmax()
print(buffer)