from C import *
import copy

ROW = 4  # row unroll factor of the blocked kernels
COL = 4  # column unroll factor of the blocked kernels


def pkg(name):
    """Return the generated C function name for `name`, prefixed ``blas·``."""
    # The separator is U+00B7 MIDDLE DOT. It had been mis-encoded as "Ā·";
    # the escape keeps this source line ASCII so that cannot happen again.
    return f"blas\u00b7{name}"


def typeify(string, kind):
    """Append the precision suffix for `kind` to `string`.

    Float32 -> ``f``, Float64 -> ``d`` (``typeify("gemv", Float64)`` is
    ``"gemvd"``).

    Raises:
        ValueError: if `kind` is neither Float32 nor Float64 (this used to
            return None silently, which then ended up in the symbol name).
    """
    if kind == Float32:
        return f"{string}f"
    if kind == Float64:
        return f"{string}d"
    raise ValueError(f"unsupported element kind: {kind!r}")


# ------------------------------------------------------------------------
# Helpers (abandoning the automatic unroll from level 1)

def toarray(len: int, *args):
    """Re-declare every variable in `args` as a `len`-element array parameter.

    Each result keeps the variable's element type and name:
    ``Param(Array(arg.type, len), arg.name)``.
    """
    return [Param(Array(arg.type, len), arg.name) for arg in args]


def TryIndex(x, i):
    """Return ``x[i]`` when `x` is declared with an array type, else `x` itself.

    Lets one statement template serve both scalar locals and their
    array-of-ROW counterparts (see gemv's `stack` vs. `toarray`).
    """
    if IsArrayType(x.var.type):
        return Index(x, i)
    return x


def AddElts(root, *terms):
    """Fold `root` and `terms` into a left-associated sum ``((root + t0) + t1) ...``."""
    for term in terms:
        root = Add(root, term)
    return root


def Round(store, number, by):
    """Emit ``store = number & Negate(by - 1)``.

    Intended to round `number` down to a multiple of `by`.
    NOTE(review): that holds only if `Negate` emits bitwise complement (``~``)
    and `by` is a power of two; if it is arithmetic negation the mask should
    be built from `by`, not `by - 1` -- confirm against the C module.
    """
    return Set(store, And(number, Negate(I(by - 1))))


def UnitIncs(root, *incs):
    """Build the condition ``root == 1 && inc0 == 1 && ...``."""
    cond = EQ(root, I(1))
    for inc in incs:
        cond = AndAnd(cond, EQ(inc, I(1)))
    return cond


def IsInc(p):
    """Parameter filter that is True for every parameter EXCEPT ``incx``/``incy``.

    Despite the name it keeps the non-stride parameters; gemv uses it to drop
    the strides from the unit-stride kernel's signature.
    """
    return p.name not in ("incx", "incy")


def Identity(p):
    """Parameter filter that keeps every parameter."""
    return True


def FilterParams(params, func):
    """Return the parameters in `params` for which `func(p)` is truthy."""
    return [p for p in params if func(p)]


def StrideAllIndexedTerms(stmts, var, itor, inc):
    """Scale, in place, the index of every ``var[itor <op> ...]`` access by `inc`.

    Walks each statement in `stmts`, collects the `Index` nodes on `var` whose
    index is a `BinaryOp` with left operand `itor`, and rewrites each such
    index `e` to ``(e) * inc``. Mutates the existing nodes and returns None.
    """
    def is_hit(node):
        if isinstance(node, Index):
            if isinstance(node.i, BinaryOp) and node.x == var:
                return node.i.l == itor
        return False

    hits = []
    for stmt in stmts:
        Visit(stmt, lambda node: Filter(node, is_hit, hits))

    for term in hits:
        term.i = Mul(Paren(term.i), inc)


def AsStrided(stmts, var, itor, inc):
    """Return a copy of `stmts` in which every ``var[<BinaryOp>]`` index is scaled by `inc`.

    Accepts a `Block` (returns a new `Block`) or a list of statements (returns
    a list). Each `Index` on `var` whose index is a `BinaryOp` becomes
    ``var[(index) * inc]``; every other node is shallow-copied.

    NOTE: `itor` is accepted for symmetry with StrideAllIndexedTerms but is
    NOT used to filter accesses. ger relies on that: it passes `c` for `x`,
    which it actually indexes by `r`.

    Raises:
        TypeError: if `stmts` is neither a Block nor a list.
    """
    def increment(node):
        if isinstance(node, Index):
            if isinstance(node.i, BinaryOp) and node.x == var:
                return Index(node.x, Mul(Paren(node.i), inc))
        return copy.copy(node)

    def rewrite(stmt):
        # Transform takes the callback itself; the list branch used to pass
        # the result of calling it (`increment(node)`) instead.
        return Make(stmt, lambda node: Transform(node, increment))

    if isinstance(stmts, Block):
        return Block(*[rewrite(stmt) for stmt in stmts.stmts])
    if isinstance(stmts, list):
        return [rewrite(stmt) for stmt in stmts]
    raise TypeError("unrecognized stmts type")


class Iter:
    """Loop descriptor consumed by DoubleLoop / TriangularLoop.

    Attributes:
        it:  the loop counter variable.
        end: bound of the stepped (unrolled by `inc`) loop.
        len: full bound; the inner scalar tail loop runs up to it.
        inc: step / unroll factor.
    """

    def __init__(self, it, end, len, inc):
        self.it = it
        self.end = end
        self.len = len
        self.inc = inc


def _loop_step(it, inc):
    """Loop step statement: ``Inc(it)`` for a unit step, else ``AddSet(it, inc)``."""
    return Inc(it) if inc == 1 else AddSet(it, I(inc))


def _per_row(funcs, rows):
    """Call each per-row callback for row offsets 0..rows-1 (callback-major order)."""
    return [func(i) for func in funcs for i in range(rows)]


def _kernel_rows(Kernel, top, bot, rows, width):
    """Concatenate ``Kernel(top.it, bot.it, i, width)`` for each row offset `i` in `rows`."""
    return [stmt for i in rows for stmt in Kernel(top.it, bot.it, i, width)]


def _rank1_kernel(row, res, vec):
    """Kernel emitting ``row[i][c + j] += res[i] * vec[c + j]`` for j in 0..width-1."""
    def kernel(_r, c, i, width):
        return [AddSet(Index(Index(row, I(i)), Add(c, I(j))),
                       Mul(Index(res, I(i)), Index(vec, Add(c, I(j)))))
                for j in range(width)]
    return kernel


def _rank1_preamble(row, res, m, incm, a, x, r):
    """Per-row setup callbacks: ``row[i] = m + (r + i) * incm`` and ``res[i] = a * x[r + i]``."""
    return [
        lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))),
        lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i))))),
    ]


def DoubleLoop(top, bot, Kernel, Preamble=(), Postamble=()):
    """Build a rectangular row/column loop nest unrolled top.inc x bot.inc.

    Roughly::

        for (top.it = 0; top.it < top.end; top.it += top.inc) {
            Preamble(i) for each row offset i < top.inc
            for (bot.it = 0; bot.it < bot.end; bot.it += bot.inc)
                Kernel(.., i, bot.inc) for each row offset i
            for (; bot.it < bot.len; bot.it++)
                Kernel(.., i, 1) for each row offset i
            Postamble(i) for each row offset i < top.inc
        }

    Args:
        top, bot: `Iter` descriptors for the outer (row) and inner (column) loop.
        Kernel: ``Kernel(row_it, col_it, i, width)`` -> list of statements for
            row offset `i` covering `width` columns.
        Preamble, Postamble: callbacks ``f(i)`` -> statement, one per row offset.
    """
    return For(Set(top.it, I(0)), LT(top.it, top.end), _loop_step(top.it, top.inc),
        Block(
            *_per_row(Preamble, top.inc),
            For(Set(bot.it, I(0)), LT(bot.it, bot.end), _loop_step(bot.it, bot.inc),
                Block(*_kernel_rows(Kernel, top, bot, range(top.inc), bot.inc))),
            For(None, LT(bot.it, bot.len), Inc(bot.it),
                Block(*_kernel_rows(Kernel, top, bot, range(top.inc), 1))),
            # Postamble is per row, like Preamble. It used to iterate bot.inc
            # (the column unroll), producing row offsets that do not exist.
            *_per_row(Postamble, top.inc),
        ))


def TriangularLoop(top, bot, Kernel, Preamble=(), Postamble=(), upper=True):
    """Build a triangular row/column loop nest (diagonal included).

    With bot.inc == 1, upper=True visits columns top.it .. bot.len-1 of each
    row, and upper=False visits columns 0 .. bot.end where
    bot.end = EvenTo(top.it, 1) (presumably top.it -- confirm EvenTo).
    Arguments are as for DoubleLoop.

    NOTE(review): the bot.inc > 1 paths (Start/Finish and the unrolled bounds)
    look unfinished -- e.g. the upper unrolled loop runs to bot.len with step
    bot.inc and has no scalar tail. All current callers use 1x1.
    """
    def Finish(j):
        rows = _kernel_rows(Kernel, top, bot, range(j, top.inc), 1)
        if j == 0:
            return For(None, LE(bot.it, top.it), Inc(bot.it), Block(*rows))
        return Block(*rows, Inc(bot.it))

    def Start(j, end):
        rows = _kernel_rows(Kernel, top, bot, range(j + 1), 1)
        if j == end:
            return For(None, LT(bot.it, bot.end), Inc(bot.it), Block(*rows))
        return Block(*rows, Inc(bot.it))

    header = (Set(top.it, I(0)), LT(top.it, top.end), _loop_step(top.it, top.inc))
    if upper:
        body = Block(
            *_per_row(Preamble, top.inc),
            Set(bot.end, Add(Paren(EvenTo(Paren(Sub(top.end, top.it)), bot.inc)), top.it)),
            Set(bot.it, top.it),
            *[Start(j, top.inc - 1) for j in range(top.inc) if bot.inc > 1],
            For(None, LT(bot.it, bot.len), _loop_step(bot.it, bot.inc),
                Block(*_kernel_rows(Kernel, top, bot, range(top.inc), bot.inc))),
            *_per_row(Postamble, top.inc),
        )
    else:
        body = Block(
            *_per_row(Preamble, top.inc),
            Set(bot.end, EvenTo(top.it, bot.inc)),
            For(Set(bot.it, I(0)), LE(bot.it, bot.end), _loop_step(bot.it, bot.inc),
                Block(*_kernel_rows(Kernel, top, bot, range(top.inc), bot.inc))),
            *[Finish(j) for j in range(top.inc) if bot.inc > 1],
            *_per_row(Postamble, top.inc),
        )
    return For(*header, body)


def ToKernel(name, loop):
    """Collect the variables referenced by `loop`'s init, cond, step and body.

    Unfinished: `name` is unused and no kernel is built yet. The union (`|`)
    of the VarsUsed(...) results is now returned instead of being discarded.
    """
    return (VarsUsed(StmtExpr(loop.init)) | VarsUsed(StmtExpr(loop.cond))
            | VarsUsed(StmtExpr(loop.step)) | VarsUsed(loop.body))

# def ExpandAdd(i: int, c: Emitter, inc: int):
#     offset = Add(c, I(0))
#     root = Mul(Index(Index(row, I(i)), offset), Index(x, offset))
#     for n in range(1, inc):
#         offset = Add(c, I(n))
#         root = Add(root, Mul(Index(Index(row, I(i)), offset), Index(x, offset)))
#     return root

# ------------------------------------------------------------------------
# Blas level 2 functions

def trsv(kind):
    """Scaffold for the triangular solve ``trsv`` (work in progress).

    NOTE(review): as written this is not a solve yet:
      * "a" is requested from F.variables, but trsv declares no "a" parameter
        (Params are flag, len, m, incm, x, incx), and the Preamble multiplies
        by it -- confirm how F.variables treats an unknown name.
      * the Kernel/Preamble are the syr rank-1 update, copied verbatim.
      * `nr` (the row-loop bound) is never assigned.
      * only the loop is emitted (`loop.emit()`), not the function F.
    """
    name = typeify("trsv", kind)
    F = Func(pkg(name), Void,
        Params(
            (UInt32, "flag"), (Int, "len"), (Ptr(kind), "m"), (Int, "incm"),
            (Ptr(kind), "x"), (Int, "incx"),
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"),
            (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res"),
        ),
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x")
    incx = F.variables("incx")

    def rows(inc_r):
        return Iter(r, nr, _len, inc_r)

    def cols(inc_c):
        return Iter(c, nc, _len, inc_c)

    def template(inc_r, inc_c):
        return TriangularLoop(rows(inc_r), cols(inc_c),
                              Kernel=_rank1_kernel(row, res, x),
                              Preamble=_rank1_preamble(row, res, m, incm, a, x, r),
                              upper=True)

    loop = template(1, 1)
    loop.emit()


def syr(kind):
    """Generate ``syr``: rank-1 update ``a * x * x^T`` of one triangle of `m`.

    For every (r, c) in the selected triangle (diagonal included) the emitted
    function adds ``(a * x[r]) * x[c]`` to ``m[r*incm + c]``. x accesses are
    scaled by incx except on the incx == 1 fast path. flag != 0 selects the
    lower triangle, flag == 0 the upper one.
    """
    name = typeify("syr", kind)
    F = Func(pkg(name), Void,
        Params(
            (UInt32, "flag"), (Int, "len"), (kind, "a"), (Ptr(kind), "x"), (Int, "incx"),
            (Ptr(kind), "m"), (Int, "incm"),
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"),
            (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res"),
        ),
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x")
    incx = F.variables("incx")

    def rows(inc_r):
        return Iter(r, nr, _len, inc_r)

    def cols(inc_c):
        return Iter(c, nc, _len, inc_c)

    def template(inc_r, inc_c, layout):
        return TriangularLoop(rows(inc_r), cols(inc_c),
                              Kernel=_rank1_kernel(row, res, x),
                              Preamble=_rank1_preamble(row, res, m, incm, a, x, r),
                              upper=layout == "upper")

    blocks = []
    for layout in ["lower", "upper"]:
        floop = template(1, 1, layout)
        sloop = template(1, 1, layout)
        sloop.body = AsStrided(sloop.body, x, c, incx)
        # Finishes rows nr .. len-1, continuing from wherever r was left.
        fini = template(1, 1, layout)
        fini.init = None
        fini.body = AsStrided(fini.body, x, c, incx)
        fini.cond = LT(r, _len)
        blocks.append(Block(
            If(UnitIncs(incx), Block(floop), Block(sloop)),
            fini,
            Return(),
        ))

    # The row loops are bounded by nr, which was previously never assigned
    # (uninitialized read in the generated C). Set it as ger does; any
    # nr <= len is correct because `fini` finishes the remaining rows.
    F.execute(Set(nr, EvenTo(_len, ROW)), If(flag, blocks[0], blocks[1]))
    F.emit()


def ger(kind):
    """Generate ``ger``: rank-1 update ``m += a * x * y^T`` of an nrow x ncol matrix.

    Element (r, c) at ``m[r*incm + c]`` gets ``(a * x[r]) * y[c]`` added.
    Rows below nr and columns below nc (nrow/ncol passed through EvenTo with
    ROW/COL) are processed in ROW x COL tiles, remaining columns by a scalar
    tail, and rows nr .. nrow-1 by `fini`. The incx == incy == 1 case uses
    unscaled indexing; otherwise x/y accesses are scaled by incx/incy.
    """
    name = typeify("ger", kind)
    F = Func(pkg(name), Void,
        Params(
            (Int, "nrow"), (Int, "ncol"), (kind, "a"), (Ptr(kind), "x"), (Int, "incx"),
            (Ptr(kind), "y"), (Int, "incy"), (Ptr(kind), "m"), (Int, "incm"),
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"),
            (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res"),
        ),
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    nrow, ncol, a, m, incm, x, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "y")
    incx, incy = F.variables("incx", "incy")

    def rows(incr):
        return Iter(r, nr, nrow, incr)

    def cols(incc):
        return Iter(c, nc, ncol, incc)

    def template(incr, incc):
        return DoubleLoop(rows(incr), cols(incc),
                          Kernel=_rank1_kernel(row, res, y),
                          Preamble=_rank1_preamble(row, res, m, incm, a, x, r))

    def strided(stmts):
        # AsStrided ignores its `itor` argument, so this scales every x access
        # (indexed by r) by incx and every y access (indexed by c) by incy.
        return AsStrided(AsStrided(stmts, x, c, incx), y, r, incy)

    floop = template(ROW, COL)
    sloop = template(ROW, COL)
    sloop.body = strided(sloop.body)

    # Leftover rows nr .. nrow-1, one at a time. The unrolled column step must
    # divide nc (a multiple of COL, assuming EvenTo rounds down). The previous
    # 2*COL step ran past nc -- and past ncol -- whenever nc / COL was odd.
    fini = template(1, COL)
    fini.init = None
    fini.body = strided(fini.body)
    fini.cond = LT(r, nrow)

    F.execute(
        Set(nr, EvenTo(nrow, ROW)),
        Set(nc, EvenTo(ncol, COL)),
        If(UnitIncs(incx, incy), Block(floop), Block(sloop)),
    )
    F.execute(fini)
    F.emit()


def gemv(kind):
    """Generate ``gemv``: ``y = a * (m x) + b * y`` for an nrow x ncol matrix `m`.

    Per row r: ``y[r] = a * sum_c(m[r*incm + c] * x[c]) + b * y[r]``, with x/y
    accesses scaled by incx/incy on the strided paths.

    Two static kernels are emitted first: ``<name>_<ROW>x<COL>kernel`` (unit
    strides, no incx/incy parameters) and ``<name>_s_<ROW>x<COL>kernel``
    (strided). They process the rows below nrow as rounded by `Round` to ROW.
    The public function picks one with UnitIncs(incx, incy), then finishes
    the remaining rows itself. A leading `flag` parameter is prepended to
    the public signature only (it is not used in the body).

    NOTE: innerloop/outerloop read `m`, `incm`, `x`, `a`, `b` and `y` from this
    function's scope at call time, i.e. from whichever `*.variables(...)`
    unpacking (a kernel's or F's) ran last before the call.
    """
    name = typeify("gemv", kind)
    params = Params(
        (Int, "nrow"), (Int, "ncol"), (kind, "a"), (Ptr(kind), "m"), (Int, "incm"),
        (Ptr(kind), "x"), (Int, "incx"), (kind, "b"), (Ptr(kind), "y"), (Int, "incy"),
    )
    stack = Vars((Int, "r"), (Int, "c"), (Ptr(kind), "row"), (kind, "res"))
    F = Func(pkg(name), Void, params, stack)

    # ---------------------
    # Kernel

    def innerloop(rinc, cit, cend, cinc, row, res):
        """``for (cit = 0; cit < cend; cit += cinc)`` accumulate ``row[i][cit+j] * x[cit+j]`` into res[i]."""
        return For(Set(cit, I(0)), LT(cit, cend), AddSet(cit, I(cinc)),
            Block(*[
                AddSet(TryIndex(res, I(i)),
                       AddElts(*(Mul(Index(TryIndex(row, I(i)), Add(cit, I(j))),
                                     Index(x, Add(cit, I(j))))
                                 for j in range(cinc))))
                for i in range(rinc)
            ]))

    def tryloop(rinc, cit, cend, cinc, row, res):
        """Scalar column tail continuing from `cit` up to `cend`, as a 0/1-element list.

        Empty when cinc <= 1: the main loop already covers every column (the
        old code raised UnboundLocalError in that case).
        """
        if cinc <= 1:
            return []
        loop = innerloop(rinc, cit, cend, 1, row, res)
        loop.init = None
        return [loop]

    def outerloop(rit, rlen, rinc, cit, cend, clen, cinc, row, res):
        """Row loop stepping `rit` by `rinc`; computes `rinc` entries of y per iteration.

        `cend` bounds the cinc-unrolled column loop and `clen` its scalar tail.
        """
        return For(Set(rit, I(0)), LT(rit, rlen), AddSet(rit, I(rinc)),
            Block(
                *[Set(TryIndex(row, I(i)), Add(m, Mul(Paren(Add(rit, I(i))), incm)))
                  for i in range(rinc)],
                *[Set(TryIndex(res, I(i)), I(0)) for i in range(rinc)],
                innerloop(rinc, cit, cend, cinc, row, res),
                *tryloop(rinc, cit, clen, cinc, row, res),
                *[Set(Index(y, Add(rit, I(i))),
                      Add(Mul(a, TryIndex(res, I(i))), Mul(b, Index(y, Add(rit, I(i))))))
                  for i in range(rinc)],
            ))

    kerns = []
    for func, sfx in [(IsInc, ""), (Identity, "_s")]:
        kern = Func(f"{name}{sfx}_{ROW}x{COL}kernel", Void, FilterParams(params, func),
                    stack[0:2] + toarray(ROW, *stack[2:]), static=True)
        # Rebinding these gemv locals is what the loop builders above read.
        r, c, row, res = kern.variables("r", "c", "row", "res")
        nrow, ncol, a, m, incm, x, b, y = kern.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")
        ncolr = kern.declare(Var(Int, "ncolr"))

        loop = outerloop(r, nrow, ROW, c, ncolr, ncol, COL, row, res)
        kern.execute(Round(ncolr, ncol, COL))
        kern.execute(loop)

        if sfx == "_s":
            incx, incy = kern.variables("incx", "incy")
            StrideAllIndexedTerms(kern.stmts, x, c, incx)
            StrideAllIndexedTerms(kern.stmts, y, r, incy)

        kern.emit()
        kerns.append(kern)

    r, c, row, res = F.variables("r", "c", "row", "res")
    nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")
    incx, incy = F.variables("incx", "incy")

    # r = rows handled by the kernels; the kernels receive it as their nrow.
    F.execute(Round(r, nrow, ROW))
    F.execute(
        If(UnitIncs(incx, incy),
           Block(Call(kerns[0], [r, ncol, a, m, incm, x, b, y])),
           Block(Call(kerns[1], [r, ncol, a, m, incm, x, incx, b, y, incy])),
        )
    )
    F.params = Params((UInt32, "flag")) + F.params

    # Leftover rows r .. nrow-1, one at a time. Columns use step 1: the old
    # COL-unrolled loop ran to ncol itself (not a rounded bound) and read up
    # to COL-1 elements past the row and past x whenever ncol % COL != 0.
    remainder = outerloop(r, nrow, 1, c, ncol, ncol, 1, row, res)
    remainder.init = None
    F.execute(remainder)
    StrideAllIndexedTerms(F, x, c, incx)
    StrideAllIndexedTerms(F, y, r, incy)
    F.emit()


# ------------------------------------------------------------------------
# Code Generation

if __name__ == "__main__":
    # NOTE(review): the header names of these three #include lines appear to
    # have been lost (likely stripped "<...>" text) -- restore them.
    emit("#include \n")
    emit("#include \n")
    emit("#include \n")
    emitln()
    emit("/*********************************************************/\n")
    emit("/* THIS CODE IS GENERATED BY GEN2.PY! DON'T EDIT BY HAND */\n")
    emit("/*********************************************************/\n")
    emitln(2)

    for kind in [Float64]:  # [Float32, Float64]:
        trsv(kind)
        # syr(kind)
        # ger(kind)
        # gemv(kind)

    flush()