about | summary | refs | log | tree | commit | diff
path: root/sys/libmath/gen2.py
diff options
context:
space:
mode:
Diffstat (limited to 'sys/libmath/gen2.py')
-rwxr-xr-x  sys/libmath/gen2.py  390
1 files changed, 390 insertions, 0 deletions
diff --git a/sys/libmath/gen2.py b/sys/libmath/gen2.py
new file mode 100755
index 0000000..6ce2a12
--- /dev/null
+++ b/sys/libmath/gen2.py
@@ -0,0 +1,390 @@
+from C import *
+import copy
+
# Unroll factors used by the generators below: kernels step ROW rows and
# COL columns at a time, and the per-row scratch arrays are sized with them.
ROW, COL = 4, 4
+
def pkg(name):
    """Return ``name`` prefixed with the library's ``blas`` package prefix."""
    # NOTE(review): the separator after "blas" looks like a mis-encoded
    # middle dot; confirm against the symbol names expected by the C side.
    return "blasĀ·{}".format(name)
+
def typeify(string, kind):
    """Append the element-type suffix for ``kind`` to ``string``.

    Args:
        string: Base routine name (e.g. ``"gemv"``).
        kind: Element type; ``Float32`` maps to suffix ``f`` and ``Float64``
            to suffix ``d``.

    Returns:
        ``string`` followed by the one-letter suffix.

    Raises:
        ValueError: If ``kind`` is neither ``Float32`` nor ``Float64``.
            Previously this case fell through and returned ``None``, which
            silently produced symbol names containing "None".
    """
    if kind == Float32:
        return f"{string}f"
    if kind == Float64:
        return f"{string}d"
    raise ValueError(f"typeify: unsupported element kind {kind!r}")
+
+# ------------------------------------------------------------------------
+# Helpers (abandoning the automatic unroll from level 1)
+
def toarray(len: int, *args):
    """Re-declare each of ``args`` as an array parameter of length ``len``.

    Every argument must expose ``.type`` and ``.name``; the result is a list
    of ``Param`` objects with type ``Array(arg.type, len)`` and the same name.
    """
    arrays = []
    for item in args:
        arrays.append(Param(Array(item.type, len), item.name))
    return arrays
+
def TryIndex(x, i):
    """Return ``Index(x, i)`` if ``x``'s variable has an array type, else ``x``.

    Lets the same generator code work with both scalar and array locals.
    """
    return Index(x, i) if IsArrayType(x.var.type) else x
+
def AddElts(root, *vars):
    """Left-fold ``Add`` over ``root`` and the remaining terms.

    ``AddElts(a, b, c)`` builds ``Add(Add(a, b), c)``; with no extra terms
    ``root`` is returned unchanged.
    """
    total = root
    for term in vars:
        total = Add(total, term)
    return total
+
def Round(store, number, by):
    """Build ``store = number & Negate(by - 1)``.

    NOTE(review): presumably rounds ``number`` down to a multiple of ``by``
    (a power of two), which requires ``Negate`` to be a bitwise complement;
    confirm against the ``C`` module.
    """
    mask = Negate(I(by - 1))
    return Set(store, And(number, mask))
+
def UnitIncs(root, *incs):
    """Build the condition that ``root`` and every one of ``incs`` equal 1.

    The individual ``EQ(inc, 1)`` tests are chained left to right with
    ``AndAnd``.
    """
    cond = EQ(root, I(1))
    for step in incs:
        cond = AndAnd(cond, EQ(step, I(1)))
    return cond
+
def IsInc(p):
    """Return True when ``p`` is NOT one of the stride parameters.

    Despite the name, this predicate keeps every parameter except those
    named ``incx`` or ``incy``.
    """
    return p.name not in ("incx", "incy")
+
def Identity(p):
    """Predicate that accepts every parameter (always returns True)."""
    return True
+
def FilterParams(params, func):
    """Return a list of the entries of ``params`` for which ``func`` is truthy.

    Order is preserved and ``params`` itself is not modified.
    """
    kept = []
    for param in params:
        if func(param):
            kept.append(param)
    return kept
+
def StrideAllIndexedTerms(stmts, var, itor, inc):
    """Multiply, in place, the index of matching ``Index`` nodes by ``inc``.

    A node matches when it is an ``Index`` over ``var`` whose index
    expression is a ``BinaryOp`` with ``itor`` as its left operand. Each
    match has its index rewritten to ``(index) * inc``.

    Args:
        stmts: Iterable of statements to search.
        var: The indexed variable to look for.
        itor: The iterator expected on the left of the index expression.
        inc: The stride expression to multiply by.
    """
    def matches(node):
        if not isinstance(node, Index):
            return False
        if isinstance(node.i, BinaryOp) and node.x == var:
            return node.i.l == itor
        return False

    hits = []

    def collect(node):
        # Presumably Filter appends `node` to `hits` when `matches` accepts it.
        return Filter(node, matches, hits)

    for stmt in stmts:
        Visit(stmt, collect)

    # Rewrite only after the walk so the traversal never sees rewritten nodes.
    for hit in hits:
        hit.i = Mul(Paren(hit.i), inc)
+
def AsStrided(stmts, var, itor, inc):
    """Return a copy of ``stmts`` with indices into ``var`` scaled by ``inc``.

    Every ``Index`` node over ``var`` whose index expression is a
    ``BinaryOp`` is rebuilt as ``Index(var, (index) * inc)``; every other
    node is shallow-copied. The input statements are not modified.

    Args:
        stmts: Either a ``Block`` or a ``list`` of statements.
        var: The indexed variable whose accesses should be strided.
        itor: Accepted for signature parity with ``StrideAllIndexedTerms``;
            it is not consulted here.
        inc: The stride expression to multiply by.

    Returns:
        A new ``Block`` when given a ``Block``, a new ``list`` when given a
        ``list``.

    Raises:
        TypeError: If ``stmts`` is neither a ``Block`` nor a ``list``.
    """
    def increment(x):
        if isinstance(x, Index):
            if isinstance(x.i, BinaryOp) and x.x == var:
                return Index(x.x, Mul(Paren(x.i), inc))

        return copy.copy(x)

    def restride(stmt):
        # Transform must receive the callable itself. The list branch used to
        # pass increment(node), i.e. an already-built node, not a function.
        return Make(stmt, lambda node: Transform(node, increment))

    if isinstance(stmts, Block):
        return Block(*[restride(stmt) for stmt in stmts.stmts])
    if isinstance(stmts, list):
        return [restride(stmt) for stmt in stmts]
    raise TypeError("unrecognized stmts type")
+
class Iter:
    """Description of one loop dimension used by the loop builders.

    Attributes are stored exactly as given:
      it  -- the loop variable
      end -- the bound compared against by the unrolled loop
      len -- the bound compared against by the element-wise tail loop
      inc -- the unroll/step factor (a Python int)
    """

    def __init__(self, it, end, len, inc):
        self.it, self.end = it, end
        self.len, self.inc = len, inc
+
def DoubleLoop(top, bot, Kernel, Preamble=(), Postamble=()):
    """Build a rectangular two-level loop nest, unrolled by the Iter steps.

    The outer loop runs ``top.it`` from 0 while ``top.it < top.end`` in steps
    of ``top.inc``. Its body contains, in order:

      1. ``func(i)`` for each ``func`` in ``Preamble`` and each ``i`` in
         ``range(top.inc)``;
      2. an inner loop over ``bot.it`` from 0 while ``bot.it < bot.end`` in
         steps of ``bot.inc``, holding the kernel statements for every
         unrolled row;
      3. a tail loop that continues ``bot.it`` one element at a time while
         ``bot.it < bot.len``;
      4. ``func(i)`` for each ``func`` in ``Postamble``.

    Rows between ``top.end`` and ``top.len`` are not handled here.

    Args:
        top: ``Iter`` for the outer dimension.
        bot: ``Iter`` for the inner dimension.
        Kernel: Callable ``(top_it, bot_it, i, width)`` returning an iterable
            of statements for unrolled row ``i`` covering ``width`` columns.
        Preamble: Iterable of callables ``i -> statement``.
        Postamble: Iterable of callables ``i -> statement``.

    Returns:
        The outer ``For`` node.
    """
    # The defaults are immutable tuples so no list is shared across calls.

    def Step(it, inc):
        # Use the increment form for unit steps, "+= inc" otherwise.
        return Inc(it) if inc == 1 else AddSet(it, I(inc))

    return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc),
        Block(
            *[func(i) for func in Preamble for i in range(top.inc)],
            For(Set(bot.it, I(0)), LT(bot.it, bot.end), Step(bot.it, bot.inc),
                Block(*[stmt for i in range(top.inc) for stmt in Kernel(top.it, bot.it, i, bot.inc)])
            ),
            # No init: the tail resumes from wherever the unrolled loop stopped.
            For(None, LT(bot.it, bot.len), Inc(bot.it),
                Block(*[stmt for i in range(top.inc) for stmt in Kernel(top.it, bot.it, i, 1)])
            ),
            # NOTE(review): Preamble is expanded over range(top.inc) but
            # Postamble over range(bot.inc); confirm this is intended.
            *[func(i) for func in Postamble for i in range(bot.inc)],
        )
    )
+
def TriangularLoop(top, bot, Kernel, Preamble=(), Postamble=(), upper=True):
    """Build a two-level loop nest over one triangle of a square index space.

    The outer loop runs ``top.it`` from 0 while ``top.it < top.end`` in steps
    of ``top.inc``. With ``upper`` true the inner iterator starts at
    ``top.it``; otherwise it starts at 0 and the unrolled loop runs while
    ``bot.it <= bot.end``. When ``bot.inc > 1`` extra single-column steps
    (``Start``/``Finish``) are generated, covering a varying number of the
    unrolled rows.

    Args:
        top: ``Iter`` for the outer dimension.
        bot: ``Iter`` for the inner dimension; ``bot.end`` is assigned inside
            the generated outer-loop body.
        Kernel: Callable ``(top_it, bot_it, i, width)`` returning an iterable
            of statements for unrolled row ``i`` covering ``width`` columns.
        Preamble: Iterable of callables ``i -> statement``.
        Postamble: Iterable of callables ``i -> statement``.
        upper: Selects the upper-triangle layout when true, lower otherwise.

    Returns:
        The outer ``For`` node.
    """
    # The defaults are immutable tuples so no list is shared across calls.

    def Step(it, inc):
        # Use the increment form for unit steps, "+= inc" otherwise.
        return Inc(it) if inc == 1 else AddSet(it, I(inc))

    def Rows(indices, width):
        # Kernel statements for the given unrolled row indices.
        return [stmt for i in indices for stmt in Kernel(top.it, bot.it, i, width)]

    def Finish(j):
        # Lower layout: single-column steps emitted after the unrolled loop.
        if j == 0:
            return For(None, LE(bot.it, top.it), Inc(bot.it),
                Block(*Rows(range(j, top.inc), 1))
            )
        return Block(*Rows(range(j, top.inc), 1), Inc(bot.it))

    def Start(j, end):
        # Upper layout: single-column steps emitted before the unrolled loop.
        if j == end:
            return For(None, LT(bot.it, bot.end), Inc(bot.it),
                Block(*Rows(range(j + 1), 1))
            )
        return Block(*Rows(range(j + 1), 1), Inc(bot.it))

    if upper:
        return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc),
            Block(
                *[func(i) for func in Preamble for i in range(top.inc)],
                Set(bot.end, Add(Paren(EvenTo(Paren(Sub(top.end, top.it)), bot.inc)), top.it)),
                Set(bot.it, top.it),
                *[Start(j, top.inc - 1) for j in range(top.inc) if bot.inc > 1],
                # NOTE(review): this unrolled loop is bounded by bot.len while
                # the Start loop above is bounded by the freshly computed
                # bot.end; confirm the two bounds are not swapped.
                For(None, LT(bot.it, bot.len), Step(bot.it, bot.inc),
                    Block(*Rows(range(top.inc), bot.inc))
                ),
                # NOTE(review): Postamble is expanded over range(bot.inc),
                # Preamble over range(top.inc); confirm this is intended.
                *[func(i) for func in Postamble for i in range(bot.inc)],
            )
        )

    return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc),
        Block(
            *[func(i) for func in Preamble for i in range(top.inc)],
            Set(bot.end, EvenTo(top.it, bot.inc)),
            For(Set(bot.it, I(0)), LE(bot.it, bot.end), Step(bot.it, bot.inc),
                Block(*Rows(range(top.inc), bot.inc))
            ),
            *[Finish(j) for j in range(top.inc) if bot.inc > 1],
            *[func(i) for func in Postamble for i in range(bot.inc)],
        )
    )
+
+
+
def ToKernel(name, loop):
    """Collect the variables used by ``loop`` (unfinished stub).

    The union of the variables referenced by the loop's init, condition,
    step and body is computed but currently discarded; ``name`` is unused
    and the function returns ``None``.
    """
    used = VarsUsed(StmtExpr(loop.init))
    used = used | VarsUsed(StmtExpr(loop.cond))
    used = used | VarsUsed(StmtExpr(loop.step))
    used = used | VarsUsed(loop.body)
+
+# def ExpandAdd(i: int, c: Emitter, inc: int):
+# offset = Add(c, I(0))
+# root = Mul(Index(Index(row, I(i)), offset), Index(x, offset))
+# for n in range(1, inc):
+# offset = Add(c, I(n))
+# root = Add(root, Mul(Index(Index(row, I(i)), offset), Index(x, offset)))
+# return root
+
+# ------------------------------------------------------------------------
+# Blas level 2 functions
+
def trsv(kind):
    """Emit the (work-in-progress) triangular loop for ``trsv`` of ``kind``.

    Only the loop itself is emitted via ``loop.emit()``; the surrounding
    function ``F`` is built but never executed or emitted.

    NOTE(review): the loop body is the same update as in ``syr`` and reads a
    variable "a" that is not declared in this function's Params/Vars;
    confirm how ``F.variables`` treats unknown names.
    """
    name = typeify("trsv", kind)
    F = Func(pkg(name), Void,
        Params(
            (UInt32, "flag"), (Int, "len"), (Ptr(kind), "m"), (Int, "incm"), (Ptr(kind), "x"), (Int, "incx")
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
        )
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x")
    incx = F.variables("incx")

    def rows(inc_r):
        return Iter(r, nr, _len, inc_r)

    def cols(inc_c):
        return Iter(c, nc, _len, inc_c)

    def kernel(_row_it, col_it, i, inc):
        # row[i][c + j] += res[i] * x[c + j] for each unrolled column j.
        return [
            AddSet(Index(Index(row, I(i)), Add(col_it, I(j))), Mul(Index(res, I(i)), Index(x, Add(col_it, I(j)))))
            for j in range(inc)
        ]

    def set_row(i):
        # row[i] = m + (r + i) * incm
        return Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm)))

    def set_res(i):
        # res[i] = a * x[r + i]
        return Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))

    def template(inc_r, inc_c):
        return TriangularLoop(rows(inc_r), cols(inc_c),
            Kernel=kernel,
            Preamble=[set_row, set_res],
            upper=True
        )

    loop = template(1, 1)
    loop.emit()
+
def syr(kind):
    """Emit the C function for the ``syr`` update of element type ``kind``.

    The generated body updates ``row[i][c + j] += (a * x[r + i]) * x[c + j]``
    over one triangle of ``m``. Two variants are built, "lower" first and
    "upper" second, and the runtime ``flag`` selects between them: the
    "lower" block is the then-branch. Each variant picks a unit-stride or
    strided loop on ``incx == 1`` and is followed by a strided clean-up loop
    that continues while ``r < len``.
    """
    name = typeify("syr", kind)
    F = Func(pkg(name), Void,
        Params(
            (UInt32, "flag"), (Int, "len"), (kind, "a"),
            (Ptr(kind), "x"), (Int, "incx"), (Ptr(kind), "m"), (Int, "incm"),
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
        )
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x")
    incx = F.variables("incx")

    def rows(inc_r):
        return Iter(r, nr, _len, inc_r)

    def cols(inc_c):
        return Iter(c, nc, _len, inc_c)

    def kernel(_row_it, col_it, i, inc):
        # row[i][c + j] += res[i] * x[c + j] for each unrolled column j.
        return [
            AddSet(Index(Index(row, I(i)), Add(col_it, I(j))), Mul(Index(res, I(i)), Index(x, Add(col_it, I(j)))))
            for j in range(inc)
        ]

    def set_row(i):
        # row[i] = m + (r + i) * incm
        return Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm)))

    def set_res(i):
        # res[i] = a * x[r + i]
        return Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))

    def template(inc_r, inc_c, upper):
        return TriangularLoop(rows(inc_r), cols(inc_c),
            Kernel=kernel,
            Preamble=[set_row, set_res],
            upper=(upper == "upper")
        )

    blocks = []
    for layout in ["lower", "upper"]:
        # Unit-stride loop and its strided twin.
        floop = template(1, 1, layout)
        sloop = template(1, 1, layout)
        sloop.body = AsStrided(sloop.body, x, c, incx)

        # Clean-up loop: no init, so it resumes from the current value of r.
        fini = template(1, 1, layout)
        fini.init = None
        fini.body = AsStrided(fini.body, x, c, incx)
        fini.cond = LT(r, _len)

        variant = Block(
            If(UnitIncs(incx), Block(floop), Block(sloop)),
            fini,
            Return(),
        )
        blocks.append(variant)

    F.execute(If(flag, blocks[0], blocks[1]))
    F.emit()
+
def ger(kind):
    """Emit the C function for the ``ger`` rank-1 update of type ``kind``.

    The generated body performs ``row[i][c + j] += (a * x[r + i]) * y[c + j]``
    with ``row[i] = m + (r + i) * incm``. Rows and columns are unrolled by
    ``ROW`` and ``COL`` up to the rounded bounds ``nr``/``nc``. A unit-stride
    or strided loop is chosen on ``incx == 1 && incy == 1``, and a final
    strided loop handles the remaining rows one at a time while ``r < nrow``.
    """
    name = typeify("ger", kind)
    F = Func(pkg(name), Void,
        Params(
            (Int, "nrow"), (Int, "ncol"), (kind, "a"),
            (Ptr(kind), "x"), (Int, "incx"), (Ptr(kind), "y"), (Int, "incy"), (Ptr(kind), "m"), (Int, "incm"),
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
        )
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    nrow, ncol, a, m, incm, x, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "y")
    incx, incy = F.variables("incx", "incy")

    def rows(incr):
        return Iter(r, nr, nrow, incr)

    def cols(incc):
        return Iter(c, nc, ncol, incc)

    def kernel(_row_it, col_it, i, inc):
        # row[i][c + j] += res[i] * y[c + j] for each unrolled column j.
        return [
            AddSet(Index(Index(row, I(i)), Add(col_it, I(j))), Mul(Index(res, I(i)), Index(y, Add(col_it, I(j)))))
            for j in range(inc)
        ]

    def set_row(i):
        # row[i] = m + (r + i) * incm
        return Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm)))

    def set_res(i):
        # res[i] = a * x[r + i]
        return Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))

    def template(incr, incc):
        return DoubleLoop(rows(incr), cols(incc),
            Kernel=kernel,
            Preamble=[set_row, set_res],
        )

    # For debugging, template(1, 1) gives the plain non-unrolled loop nest.
    floop = template(ROW, COL)
    sloop = template(ROW, COL)
    # AsStrided ignores its iterator argument, so x and y are both strided
    # even though x is indexed by r and y by c in the kernel.
    sloop.body = AsStrided(AsStrided(sloop.body, x, c, incx), y, r, incy)

    # Clean-up over the leftover rows. The column unroll must be COL: nc is
    # only aligned to COL (assuming EvenTo rounds down to a multiple of its
    # second argument), so the former 2*COL unroll could step past nc, and
    # past ncol, whenever nc is an odd multiple of COL.
    fini = template(1, COL)
    fini.init = None
    fini.body = AsStrided(AsStrided(fini.body, x, c, incx), y, r, incy)
    fini.cond = LT(r, nrow)

    F.execute(
        Set(nr, EvenTo(nrow, ROW)),
        Set(nc, EvenTo(ncol, COL)),
        If(UnitIncs(incx, incy), Block(floop), Block(sloop))
    )
    F.execute(fini)
    F.emit()
+
def gemv(kind):
    """Emit the C source for ``gemv`` of element type ``kind``.

    Three functions are generated:

      * a static ``ROW x COL`` unit-stride kernel (no ``incx``/``incy``);
      * a static ``ROW x COL`` strided kernel (suffix ``_s``);
      * the public entry point. It rounds the row count with ``Round`` into
        ``r``, calls one of the kernels for those rows depending on
        ``incx == 1 && incy == 1``, then handles the remaining rows one at a
        time with strided accesses.

    Per row the generated code computes ``res = sum(row[c] * x[c])`` and then
    stores ``y[r] = a * res + b * y[r]``.
    """
    name = typeify("gemv", kind)
    params = Params(
        (Int, "nrow"), (Int, "ncol"), (kind, "a"), (Ptr(kind), "m"), (Int, "incm"),
        (Ptr(kind), "x"), (Int, "incx"), (kind, "b"), (Ptr(kind), "y"), (Int, "incy")
    )
    stack = Vars((Int, "r"), (Int, "c"), (Ptr(kind), "row"), (kind, "res"))
    F = Func(pkg(name), Void, params, stack)

    # ---------------------
    # Kernel
    #
    # The helpers below read row/res/x/m/incm/a/b/y from this function's
    # scope at call time, so those names must be bound to the variables of
    # the function being generated before a helper is invoked.

    def innerloop(rinc, cit, cend, cinc):
        # for (c = 0; c < cend; c += cinc)
        #     res[i] += row[i][c+0]*x[c+0] + ... + row[i][c+cinc-1]*x[c+cinc-1]
        return For(Set(cit, I(0)), LT(cit, cend), AddSet(cit, I(cinc)),
            Block(*[
                AddSet(TryIndex(res, I(i)),
                    AddElts(*(
                        Mul(Index(TryIndex(row, I(i)), Add(cit, I(j))), Index(x, Add(cit, I(j))))
                        for j in range(cinc)
                    ))
                )
                for i in range(rinc)
            ])
        )

    def tryloop(rinc, cit, cend, cinc):
        # Element-wise tail for the columns an unrolled inner loop leaves
        # over; returns None when the inner loop is not unrolled.
        if cinc <= 1:
            return None
        loop = innerloop(rinc, cit, cend, 1)
        loop.init = None
        return loop

    def outerloop(rit, rlen, rinc, cit, cend, clen, cinc, row, res):
        init, cond = Set(rit, I(0)), LT(rit, rlen)
        # Step the iterator that was passed in. This used to step the
        # captured `r`, which only worked because every caller passes r.
        step = AddSet(rit, I(rinc))

        body = [
            *[Set(TryIndex(row, I(i)), Add(m, Mul(Paren(Add(rit, I(i))), incm))) for i in range(rinc)],
            *[Set(TryIndex(res, I(i)), I(0)) for i in range(rinc)],
            innerloop(rinc, cit, cend, cinc),
        ]
        tail = tryloop(rinc, cit, clen, cinc)
        if tail is not None:
            # Never hand a None placeholder to Block.
            body.append(tail)
        body.extend(
            Set(Index(y, Add(rit, I(i))), Add(Mul(a, TryIndex(res, I(i))), Mul(b, Index(y, Add(rit, I(i))))))
            for i in range(rinc)
        )
        return For(init, cond, step, Block(*body))

    kerns = []
    for func, sfx in [(IsInc, ""), (Identity, "_s")]:
        kern = Func(f"{name}{sfx}_{ROW}x{COL}kernel", Void, FilterParams(params, func), stack[0:2] + toarray(ROW, *stack[2:]), static=True)
        r, c, row, res = kern.variables("r", "c", "row", "res")
        nrow, ncol, a, m, incm, x, b, y = kern.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")

        ncolr = kern.declare(Var(Int, "ncolr"))
        loop = outerloop(r, nrow, ROW, c, ncolr, ncol, COL, row, res)

        kern.execute(Round(ncolr, ncol, COL))
        kern.execute(loop)
        if "_s" in sfx:
            incx, incy = kern.variables("incx", "incy")
            StrideAllIndexedTerms(kern.stmts, x, c, incx)
            StrideAllIndexedTerms(kern.stmts, y, r, incy)

        kern.emit()

        kerns.append(kern)

    r, c, row, res = F.variables("r", "c", "row", "res")
    nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")
    incx, incy = F.variables("incx", "incy")
    F.execute(Round(r, nrow, ROW))
    F.execute(
        If(UnitIncs(incx, incy),
            Block(Call(kerns[0], [r, ncol, a, m, incm, x, b, y])),
            Block(Call(kerns[1], [r, ncol, a, m, incm, x, incx, b, y, incy])),
        )
    )

    F.params = Params((UInt32, "flag")) + F.params

    # Remainder rows, one at a time. As in the kernels, the COL-unrolled
    # inner loop must stop at the rounded column count and leave the rest to
    # the element-wise tail; bounding it by ncol itself read past the end of
    # the row whenever ncol was not a multiple of COL.
    ncolr = F.declare(Var(Int, "ncolr"))
    F.execute(Round(ncolr, ncol, COL))
    remainder = outerloop(r, nrow, 1, c, ncolr, ncol, COL, row, res)
    remainder.init = None
    F.execute(remainder)
    # Pass the statement list, matching how the kernels are strided above.
    StrideAllIndexedTerms(F.stmts, x, c, incx)
    StrideAllIndexedTerms(F.stmts, y, r, incy)

    F.emit()
+
+# ------------------------------------------------------------------------
+# Code Generation
+
if __name__ == "__main__":
    # File prologue: the headers the generated C code is compiled against.
    for text in (
        "#include <u.h>\n",
        "#include <libn.h>\n",
        "#include <libmath.h>\n",
    ):
        emit(text)
    emitln()

    # Banner warning readers that the output is machine generated.
    for text in (
        "/*********************************************************/\n",
        "/* THIS CODE IS GENERATED BY GEN2.PY! DON'T EDIT BY HAND */\n",
        "/*********************************************************/\n",
    ):
        emit(text)
    emitln(2)

    # Only the double-precision trsv generator is currently enabled.
    for kind in [Float64]:  # [Float32, Float64]:
        trsv(kind)
        # syr(kind)
        # ger(kind)
        # gemv(kind)

    flush()