diff options
Diffstat (limited to 'sys/libmath/gen2.py')
-rwxr-xr-x | sys/libmath/gen2.py | 390 |
1 files changed, 390 insertions, 0 deletions
diff --git a/sys/libmath/gen2.py b/sys/libmath/gen2.py new file mode 100755 index 0000000..6ce2a12 --- /dev/null +++ b/sys/libmath/gen2.py @@ -0,0 +1,390 @@ +from C import * +import copy + +ROW = 4 +COL = 4 + +def pkg(name): + return f"blasĀ·{name}" + +def typeify(string, kind): + if (kind == Float32): + return f"{string}f" + if (kind == Float64): + return f"{string}d" + +# ------------------------------------------------------------------------ +# Helpers (abandoning the automatic unroll from level 1) + +def toarray(len: int, *args): + return [Param(Array(arg.type, len), arg.name) for arg in args] + +def TryIndex(x, i): + if IsArrayType(x.var.type): + return Index(x, i) + return x + +def AddElts(root, *vars): + for var in vars: + root = Add(root ,var) + return root + +def Round(store, number, by): + return Set(store, And(number, Negate(I(by-1)))) + +def UnitIncs(root, *incs): + root = EQ(root, I(1)) + for inc in incs: + root = AndAnd(root, EQ(inc, I(1))) + return root + +def IsInc(p): + return p.name != "incx" and p.name != "incy" + +def Identity(p): + return True + +def FilterParams(params, func): + return [p for p in params if func(p)] + +def StrideAllIndexedTerms(stmts, var, itor, inc): + def is_hit(x): + if isinstance(x, Index): + if isinstance(x.i, BinaryOp) and x.x == var: + return x.i.l == itor + + return False + + terms = [] + for stmt in stmts: + Visit(stmt, lambda node: Filter(node, is_hit, terms)) + + for term in terms: + term.i = Mul(Paren(term.i), inc) + +def AsStrided(stmts, var, itor, inc): + def increment(x): + if isinstance(x, Index): + if isinstance(x.i, BinaryOp) and x.x == var: + return Index(x.x, Mul(Paren(x.i), inc)) + + return copy.copy(x) + + if isinstance(stmts, Block): + return Block(*[Make(stmt, lambda node: Transform(node, increment)) for stmt in stmts.stmts]) + elif isinstance(stmts, list): + return [Make(stmt, lambda node: Transform(node, increment(node))) for stmt in stmts] + else: + raise TypeError("unrecognized stmts type") + +class Iter(object): + def __init__(self, it, end, len, inc): + self.it = it + self.end = end + self.len = len + self.inc = inc + +def DoubleLoop(top, bot, Kernel, Preamble=[], Postamble=[]): + def Step(it, inc): + if inc == 1: + return Inc(it) + else: + return AddSet(it, I(inc)) + + return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc), + Block(*[ + *[func(i) for func in Preamble for i in range(top.inc)], + For(Set(bot.it, I(0)), LT(bot.it, bot.end), Step(bot.it, bot.inc), + Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)]) + ), + For(None, LT(bot.it, bot.len), Inc(bot.it), + Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, 1)]) + ), + *[func(i) for func in Postamble for i in range(bot.inc)] + ]) + ) + +def TriangularLoop(top, bot, Kernel, Preamble=[], Postamble=[], upper=True): + def Step(it, inc): + if inc == 1: + return Inc(it) + else: + return AddSet(it, I(inc)) + + def Finish(j): + if j == 0: + return For(None, LE(bot.it, top.it), Inc(bot.it), + Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)]) + ) + else: + return Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)], Inc(bot.it)) + + def Start(j, end): + if j == end: + return For(None, LT(bot.it, bot.end), Inc(bot.it), + Block(*[func for i in range(j+1) for func in Kernel(top.it, bot.it, i, 1)]) + ) + else: + return Block(*[func for i in range(j+1) for func in Kernel(top.it, bot.it, i, 1)], Inc(bot.it)) + + if upper: + return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc), + Block(*[ + *[func(i) for func in Preamble for i in range(top.inc)], + Set(bot.end, Add(Paren(EvenTo(Paren(Sub(top.end, top.it)), bot.inc)), top.it)), + Set(bot.it, top.it), + *[ Start(j, top.inc-1) for j in range(top.inc) if bot.inc > 1], + For(None, LT(bot.it, bot.len), Step(bot.it, bot.inc), + Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)]) + ), + *[func(i) for func in Postamble for i in range(bot.inc)] + ]) + ) + else: + return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc), + Block(*[ + *[func(i) for func in Preamble for i in range(top.inc)], + Set(bot.end, EvenTo(top.it, bot.inc)), + For(Set(bot.it, I(0)), LE(bot.it, bot.end), Step(bot.it, bot.inc), + Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)]) + ), + *[ Finish(j) for j in range(top.inc) if bot.inc > 1], + *[func(i) for func in Postamble for i in range(bot.inc)] + ]) + ) + + + +def ToKernel(name, loop): + vars = VarsUsed(StmtExpr(loop.init)) | VarsUsed(StmtExpr(loop.cond)) | \ + VarsUsed(StmtExpr(loop.step)) | VarsUsed(loop.body) + +# def ExpandAdd(i: int, c: Emitter, inc: int): +# offset = Add(c, I(0)) +# root = Mul(Index(Index(row, I(i)), offset), Index(x, offset)) +# for n in range(1, inc): +# offset = Add(c, I(n)) +# root = Add(root, Mul(Index(Index(row, I(i)), offset), Index(x, offset))) +# return root + +# ------------------------------------------------------------------------ +# Blas level 2 functions + +def trsv(kind): + name = typeify("trsv", kind) + F = Func(pkg(name), Void, + Params( + (UInt32, "flag"), (Int, "len"), (Ptr(kind), "m"), (Int, "incm"), (Ptr(kind), "x"), (Int, "incx") + ), + Vars( + (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res") + ) + ) + + r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res") + flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x") + incx = F.variables("incx") + + rows, cols = lambda inc_r: Iter(r, nr, _len, inc_r), lambda inc_c: Iter(c, nc, _len, inc_c) + + template = lambda inc_r, inc_c: TriangularLoop(rows(inc_r), cols(inc_c), + Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Add(c, I(j))), Mul(Index(res, I(i)), Index(x, Add(c, I(j))))) for j in range(inc)], + Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))), + lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))], + upper = True + ) + + loop = template(1, 1) + loop.emit() + +def syr(kind): + name = typeify("syr", kind) + F = Func(pkg(name), Void, + Params( + (UInt32, "flag"), (Int, "len"), (kind, "a"), + (Ptr(kind), "x"), (Int, "incx"), (Ptr(kind), "m"), (Int, "incm"), + ), + Vars( + (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res") + ) + ) + + r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res") + flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x") + incx = F.variables("incx") + + rows, cols = lambda inc_r: Iter(r, nr, _len, inc_r), lambda inc_c: Iter(c, nc, _len, inc_c) + + template = lambda inc_r, inc_c, upper: TriangularLoop(rows(inc_r), cols(inc_c), + Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Add(c, I(j))), Mul(Index(res, I(i)), Index(x, Add(c, I(j))))) for j in range(inc)], + Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))), + lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))], + upper = upper == "upper" + ) + + blocks = [] + for layout in ["lower", "upper"]: + floop = template(1, 1, layout) + sloop = template(1, 1, layout) + sloop.body = AsStrided(sloop.body, x, c, incx) + + fini = template(1, 1, layout) + fini.init = None + fini.body = AsStrided(fini.body, x, c, incx) + fini.cond = LT(r, _len) + + blocks.append( + Block( + If(UnitIncs(incx), Block(floop), Block(sloop)), + fini, + Return(), + ) + ) + F.execute(If(flag, blocks[0], blocks[1])) + F.emit() + +def ger(kind): + name = typeify("ger", kind) + F = Func(pkg(name), Void, + Params( + (Int, "nrow"), (Int, "ncol"), (kind, "a"), + (Ptr(kind), "x"), (Int, "incx"), (Ptr(kind), "y"), (Int, "incy"), (Ptr(kind), "m"), (Int, "incm"), + ), + Vars( + (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res") + ) + ) + + r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res") + nrow, ncol, a, m, incm, x, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "y") + incx, incy = F.variables("incx", "incy") + + rows, cols = lambda incr: Iter(r, nr, nrow, incr), lambda incc: Iter(c, nc, ncol, incc) + + template = lambda incr, incc: DoubleLoop(rows(incr), cols(incc), + Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Add(c, I(j))), Mul(Index(res, I(i)), Index(y, Add(c, I(j))))) for j in range(inc)], + Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))), + lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))], + ) + + # loop = template(1, 1) + # F.execute(loop) + # F.emit() + floop = template(ROW, COL) + sloop = template(ROW, COL) + sloop.body = AsStrided(AsStrided(sloop.body, x, c, incx), y, r, incy) + + fini = template(1, 2*COL) + fini.init = None + fini.body = AsStrided(AsStrided(fini.body, x, c, incx), y, r, incy) + fini.cond = LT(r, nrow) + + F.execute( + Set(nr, EvenTo(nrow, ROW)), + Set(nc, EvenTo(ncol, COL)), + If(UnitIncs(incx, incy), Block(floop), Block(sloop)) + ) + F.execute(fini) + F.emit() + +def gemv(kind): + name = typeify("gemv", kind) + params = Params( + (Int, "nrow"), (Int, "ncol"), (kind, "a"), (Ptr(kind), "m"), (Int, "incm"), + (Ptr(kind), "x"), (Int, "incx"), (kind, "b"), (Ptr(kind), "y"), (Int, "incy") + ) + stack = Vars((Int, "r"), (Int, "c"), (Ptr(kind), "row"), (kind, "res")) + F = Func(pkg(name), Void, params, stack) + + # --------------------- + # Kernel + + def innerloop(rinc, cit, cend, cinc): + return For(Set(cit, I(0)), LT(cit, cend), AddSet(cit, I(cinc)), + Block(*[AddSet(TryIndex(res, I(i)), + AddElts(*(Mul( + Index(TryIndex(row, I(i)), Add(cit, I(j))), + Index(x, Add(cit, I(j)))) for j in range(cinc) + ) + ) + ) for i in range(rinc)]) + ) + + def tryloop(rinc, cit, cend, cinc): + if cinc > 1: + loop = innerloop(rinc, cit, cend, 1) + loop.init = None + return loop + + def outerloop(rit, rlen, rinc, cit, cend, clen, cinc, row, res): + return For(Set(rit, I(0)), LT(rit, rlen), AddSet(r, I(rinc)), + Block( + *[Set(TryIndex(row, I(i)), Add(m, Mul(Paren(Add(rit, I(i))), incm))) for i in range(rinc)], + *[Set(TryIndex(res, I(i)), I(0)) for i in range(rinc)], + innerloop(rinc, cit, cend, cinc), + tryloop(rinc, cit, clen, cinc), + *[Set(Index(y, Add(rit, I(i))), Add(Mul(a, TryIndex(res, I(i))), Mul(b, Index(y, Add(rit, I(i)))))) for i in range(rinc)] + ) + ) + + kerns = [] + for func, sfx in [(IsInc, ""), (Identity, "_s")]: + kern = Func(f"{name}{sfx}_{ROW}x{COL}kernel", Void, FilterParams(params, func), stack[0:2] + toarray(ROW, *stack[2:]), static=True) + r, c, row, res = kern.variables("r", "c", "row", "res") + nrow, ncol, a, m, incm, x, b, y = kern.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y") + + ncolr = kern.declare(Var(Int, "ncolr")) + loop = outerloop(r, nrow, ROW, c, ncolr, ncol, COL, row, res) + + kern.execute(Round(ncolr, ncol, COL)) + kern.execute(loop) + if "_s" in sfx: + incx, incy = kern.variables("incx", "incy") + StrideAllIndexedTerms(kern.stmts, x, c, incx) + StrideAllIndexedTerms(kern.stmts, y, r, incy) + + kern.emit() + + kerns.append(kern) + + r, c, row, res = F.variables("r", "c", "row", "res") + nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y") + incx, incy = F.variables("incx", "incy") + F.execute(Round(r, nrow, ROW)) + F.execute( + If(UnitIncs(incx, incy), + Block(Call(kerns[0], [r, ncol, a, m, incm, x, b, y])), + Block(Call(kerns[1], [r, ncol, a, m, incm, x, incx, b, y, incy])), + ) + ) + + F.params = Params((UInt32, "flag")) + F.params + + remainder = outerloop(r, nrow, 1, c, ncol, ncol, COL, row, res) + remainder.init = None + F.execute(remainder) + StrideAllIndexedTerms(F, x, c, incx) + StrideAllIndexedTerms(F, y, r, incy) + + F.emit() + +# ------------------------------------------------------------------------ +# Code Generation + +if __name__ == "__main__": + emit("#include <u.h>\n") + emit("#include <libn.h>\n") + emit("#include <libmath.h>\n") + emitln() + emit("/*********************************************************/\n") + emit("/* THIS CODE IS GENERATED BY GEN2.PY! DON'T EDIT BY HAND */\n") + emit("/*********************************************************/\n") + emitln(2) + + for kind in [Float64]: #[Float32, Float64]: + trsv(kind) + # syr(kind) + # ger(kind) + # gemv(kind) + + flush() |