"""Code generator for BLAS level 2 kernels.

Builds C syntax trees with the project-local ``C`` module and emits the
resulting source when run as a script.
"""

from C import *
import copy

# Unroll factors used for the row and column loops of the generated kernels.
ROW = 4
COL = 4


def pkg(name):
    """Return the emitted symbol name for ``name`` with the package prefix."""
    # NOTE(review): the separator after "blas" looks mis-encoded (possibly a
    # garbled middle dot) -- confirm against the emitted C before changing it.
    return f"blasĀ·{name}"


def typeify(string, kind):
    """Append the precision suffix for ``kind`` to ``string``.

    Returns ``string + "f"`` for Float32 and ``string + "d"`` for Float64.
    Any other kind falls through and yields None.
    """
    if kind == Float32:
        return f"{string}f"
    if kind == Float64:
        return f"{string}d"
    return None


# ------------------------------------------------------------------------
# Helpers (abandoning the automatic unroll from level 1)

def toarray(len: int, *args):
    """Return a Param for each of ``args`` retyped as an Array of length ``len``."""
    return [Param(Array(arg.type, len), arg.name) for arg in args]


def TryIndex(x, i):
    """Return ``Index(x, i)`` when ``x`` has an array type, otherwise ``x`` itself."""
    if IsArrayType(x.var.type):
        return Index(x, i)
    return x


def AddElts(root, *vars):
    """Fold ``root`` and every element of ``vars`` into a left-nested chain of Add nodes."""
    for var in vars:
        root = Add(root, var)
    return root


def Round(store, number, by):
    """Build ``store = And(number, Negate(by - 1))``.

    Presumably rounds ``number`` down to a multiple of ``by`` when ``by`` is
    a power of two and Negate is a bitwise complement -- TODO confirm.
    """
    return Set(store, And(number, Negate(I(by - 1))))


def UnitIncs(root, *incs):
    """Build the condition that ``root`` and every element of ``incs`` equal 1."""
    root = EQ(root, I(1))
    for inc in incs:
        root = AndAnd(root, EQ(inc, I(1)))
    return root


def IsInc(p):
    """Return True for every parameter EXCEPT ``incx`` / ``incy``.

    Used as a FilterParams predicate to drop the stride parameters from the
    unit-stride kernel signature (the name reads inverted).
    """
    return p.name != "incx" and p.name != "incy"


def Identity(p):
    """FilterParams predicate that keeps every parameter."""
    return True


def FilterParams(params, func):
    """Return the elements of ``params`` for which ``func`` is truthy."""
    return [p for p in params if func(p)]


def StrideAllIndexedTerms(stmts, var, itor, inc):
    """Multiply, in place, matching indices of ``var`` by ``inc``.

    A term matches when it is an Index of ``var`` whose index expression is
    a BinaryOp with ``itor`` as its left operand. The matching nodes are
    mutated; nothing is returned.
    """
    def is_hit(x):
        if isinstance(x, Index):
            if isinstance(x.i, BinaryOp) and x.x == var:
                return x.i.l == itor
        return False

    terms = []
    for stmt in stmts:
        Visit(stmt, lambda node: Filter(node, is_hit, terms))

    for term in terms:
        term.i = Mul(Paren(term.i), inc)


def AsStrided(stmts, var, itor, inc):
    """Return a copy of ``stmts`` with BinaryOp indices of ``var`` scaled by ``inc``.

    ``stmts`` may be a Block (a new Block is returned) or a list (a new list
    is returned); anything else raises TypeError. ``itor`` is accepted for
    symmetry with StrideAllIndexedTerms but is not consulted here.
    """
    def increment(x):
        if isinstance(x, Index):
            if isinstance(x.i, BinaryOp) and x.x == var:
                return Index(x.x, Mul(Paren(x.i), inc))
        return copy.copy(x)

    if isinstance(stmts, Block):
        return Block(*[Make(stmt, lambda node: Transform(node, increment)) for stmt in stmts.stmts])
    elif isinstance(stmts, list):
        # Pass the callback itself, exactly like the Block branch above; the
        # previous code passed the *result* of increment(node) to Transform.
        return [Make(stmt, lambda node: Transform(node, increment)) for stmt in stmts]
    else:
        raise TypeError("unrecognized stmts type")


class Iter(object):
    """Bundle of the pieces that describe one loop dimension."""

    def __init__(self, it, end, len, inc):
        self.it = it    # loop iterator variable
        self.end = end  # bound used by the unrolled loop
        self.len = len  # bound used by the unit-step clean-up loop
        self.inc = inc  # unroll factor (Python int)


def DoubleLoop(top, bot, Kernel, Preamble=(), Postamble=()):
    """Build a rectangular two-level loop nest.

    ``top`` / ``bot`` are Iter objects for the outer and inner dimension.
    ``Kernel(top.it, bot.it, i, inc)`` must return the statements for
    unrolled row ``i`` handling ``inc`` columns. ``Preamble`` and
    ``Postamble`` are iterables of callables taking an unroll index and
    returning a statement. The inner dimension gets an unrolled loop up to
    ``bot.end`` followed by a unit-step loop up to ``bot.len``.
    """
    def Step(it, inc):
        if inc == 1:
            return Inc(it)
        else:
            return AddSet(it, I(inc))

    return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc),
        Block(*[
            *[func(i) for func in Preamble for i in range(top.inc)],
            For(Set(bot.it, I(0)), LT(bot.it, bot.end), Step(bot.it, bot.inc),
                Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)])
            ),
            For(None, LT(bot.it, bot.len), Inc(bot.it),
                Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, 1)])
            ),
            # NOTE(review): Preamble is expanded over top.inc but Postamble
            # over bot.inc -- confirm this asymmetry is intended.
            *[func(i) for func in Postamble for i in range(bot.inc)]
        ])
    )


def TriangularLoop(top, bot, Kernel, Preamble=(), Postamble=lambda i: None, upper=True):
    """Build a two-level loop nest over a triangular index range.

    With ``upper`` true both loops count downwards (outer from
    ``top.end - 1``, inner from ``bot.len - 1``); otherwise both count
    upwards from 0. ``Kernel`` has the same contract as in DoubleLoop,
    ``Preamble`` is an iterable of per-row callables and ``Postamble`` is a
    single per-row callable.
    """
    # ---------------------------------------------------------------
    # Helper functions
    def Finish(j, upper=False):
        # Unit-step statements for the unrolled rows j .. top.inc - 1.
        unit = [func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)]
        compare, advance = (GT, Dec) if upper else (LT, Inc)
        if j == 0:
            # First row: loop the inner iterator until it meets the outer one.
            return Block(
                For(None, compare(bot.it, top.it), advance(bot.it), Block(*unit)),
                Postamble(j)
            )
        # Later rows: one more step of the remaining rows, then finish row j.
        return Block(*unit, advance(bot.it), Postamble(j))

    def Step(it, inc, down=False):
        if not down:
            if inc == 1:
                return Inc(it)
            else:
                return AddSet(it, I(inc))
        else:
            if inc == 1:
                return Dec(it)
            else:
                return SubSet(it, I(inc))

    # ---------------------------------------------------------------
    # Main body
    if upper:
        return For(Set(top.it, Sub(top.end, I(1))), GE(top.it, I(0)), Step(top.it, top.inc, down=True),
            Block(*[
                *[func(i) for func in Preamble for i in range(top.inc)],
                Set(bot.end,
                    Sub(Paren(EvenTo(Paren(Sub(top.it, Paren(Sub(bot.len, I(1))))), bot.inc)),
                        Paren(Sub(bot.len, I(1))))),
                For(Set(bot.it, Sub(bot.len, I(1))), GT(bot.it, bot.end), Step(bot.it, bot.inc, down=True),
                    Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)])
                ),
                *[Finish(j, upper) if bot.inc > 1 else Postamble(j) for j in range(top.inc)],
                # *[func(i) for func in Postamble for i in range(bot.inc)]
            ])
        )
    else:
        return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc),
            Block(*[
                *[func(i) for func in Preamble for i in range(top.inc)],
                Set(bot.end, EvenTo(top.it, bot.inc)),
                For(Set(bot.it, I(0)), LT(bot.it, bot.end), Step(bot.it, bot.inc),
                    Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)])
                ),
                *[Finish(j) if bot.inc > 1 else Postamble(j) for j in range(top.inc)],
                # *[func(i) for func in Postamble for i in range(bot.inc)]
            ])
        )


def Shift(x, i, uplo):
    """Return ``x - i`` for ``uplo == "upper"`` and ``x + i`` for ``"lower"``.

    ``i`` is a plain Python int; it is wrapped in ``I(...)`` here. Any other
    ``uplo`` raises ValueError.
    """
    if uplo == "upper":
        return Sub(x, I(i))
    if uplo == "lower":
        return Add(x, I(i))
    raise ValueError("unrecognized value")


def ToKernel(name, loop):
    """Collect the variables used by ``loop``.

    NOTE(review): the result is not used or returned yet; this helper looks
    unfinished.
    """
    used = VarsUsed(StmtExpr(loop.init)) | VarsUsed(StmtExpr(loop.cond)) | \
           VarsUsed(StmtExpr(loop.step)) | VarsUsed(loop.body)


def ExpandAdd(row, x, i: int, c: Emitter, inc: int, uplo):
    """Build the sum over ``inc`` columns of ``row[i][c -/+ n] * x[c -/+ n]``.

    The first term uses offset ``c + 0``; the following terms shift ``c`` by
    ``n`` in the direction selected by ``uplo`` (see Shift).
    """
    offset = Add(c, I(0))
    root = Mul(Index(Index(row, I(i)), offset), Index(x, offset))
    for n in range(1, inc):
        # Shift wraps its offset in I(...) itself, so pass the plain int.
        offset = Shift(c, n, uplo)
        root = Add(root, Mul(Index(Index(row, I(i)), offset), Index(x, offset)))
    return root


# ------------------------------------------------------------------------
# Blas level 2 functions

def trsv(kind):
    """Emit the triangular loop for the ``trsv`` routine of element type ``kind``.

    NOTE(review): only the loop is emitted (``loop.emit()``); the enclosing
    Func ``F`` is built but never executed or emitted -- looks like work in
    progress.
    """
    name = typeify("trsv", kind)
    F = Func(pkg(name), Void,
        Params(
            (UInt32, "flag"), (Int, "len"), (Ptr(kind), "m"), (Int, "incm"),
            (Ptr(kind), "x"), (Int, "incx")
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"),
            (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
        )
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    flag, _len, m, incm, x = F.variables("flag", "len", "m", "incm", "x")
    incx = F.variables("incx")

    rows, cols = lambda inc_r: Iter(r, nr, _len, inc_r), lambda inc_c: Iter(c, nc, _len, inc_c)
    template = lambda inc_r, inc_c, upper: TriangularLoop(rows(inc_r), cols(inc_c),
        Kernel = lambda r, c, i, inc: [AddSet(Index(res, I(i)), ExpandAdd(row, x, i, c, inc, upper))],
        Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Shift(r, i, upper)), incm))),
                    lambda i: Set(Index(res, I(i)), I(0))],
        # Shift wraps its offset in I(...) itself, so pass the plain int 0.
        Postamble = lambda i: Set(Index(x, Shift(c, 0, upper)),
                                  Div(Index(res, I(i)), Index(Index(row, I(i)), Shift(c, 0, upper)))),
        upper = upper == "upper"
    )

    loop = template(1, 1, "upper")
    loop.emit()


def syr(kind):
    """Build and emit the ``syr`` routine for element type ``kind``.

    For each layout ("lower", then "upper") a block is generated holding a
    unit-stride loop, a strided variant selected when ``incx != 1``, and a
    single-row clean-up loop. ``flag`` selects between the two blocks at run
    time (truthy -> "lower" block).
    """
    name = typeify("syr", kind)
    F = Func(pkg(name), Void,
        Params(
            (UInt32, "flag"), (Int, "len"), (kind, "a"), (Ptr(kind), "x"), (Int, "incx"),
            (Ptr(kind), "m"), (Int, "incm"),
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"),
            (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
        )
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x")
    incx = F.variables("incx")

    rows, cols = lambda inc_r: Iter(r, nr, _len, inc_r), lambda inc_c: Iter(c, nc, _len, inc_c)
    template = lambda inc_r, inc_c, upper: TriangularLoop(rows(inc_r), cols(inc_c),
        Kernel = lambda r, c, i, inc: [
            AddSet(Index(Index(row, I(i)), Shift(c, j, upper)),
                   Mul(Index(res, I(i)), Index(x, Shift(c, j, upper))))
            for j in range(inc)
        ],
        Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Shift(r, i, upper)), incm))),
                    lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Shift(r, i, upper))))],
        upper = upper == "upper"
    )

    blocks = []
    for layout in ["lower", "upper"]:
        floop = template(2, 2, layout)
        sloop = template(2, 2, layout)
        sloop.body = AsStrided(sloop.body, x, c, incx)

        # Clean-up loop: continue from wherever the unrolled loop stopped.
        fini = template(1, 1, layout)
        fini.init = None
        fini.body = AsStrided(fini.body, x, c, incx)
        fini.cond = LT(r, _len)

        blocks.append(
            Block(
                If(UnitIncs(incx), Block(floop), Block(sloop)),
                fini,
                Return(),
            )
        )

    F.execute(
        Set(nr, EvenTo(_len, ROW)),
        If(flag, blocks[0], blocks[1])
    )

    F.emit()


def ger(kind):
    """Build and emit the ``ger`` routine for element type ``kind``.

    Generates a ROW x COL unrolled loop nest (unit-stride and strided
    variants, selected on ``incx`` / ``incy``) followed by a clean-up loop
    over the remaining rows.
    """
    name = typeify("ger", kind)
    F = Func(pkg(name), Void,
        Params(
            (Int, "nrow"), (Int, "ncol"), (kind, "a"), (Ptr(kind), "x"), (Int, "incx"),
            (Ptr(kind), "y"), (Int, "incy"), (Ptr(kind), "m"), (Int, "incm"),
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"),
            (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
        )
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    nrow, ncol, a, m, incm, x, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "y")
    incx, incy = F.variables("incx", "incy")

    rows, cols = lambda incr: Iter(r, nr, nrow, incr), lambda incc: Iter(c, nc, ncol, incc)
    template = lambda incr, incc: DoubleLoop(rows(incr), cols(incc),
        Kernel = lambda r, c, i, inc: [
            AddSet(Index(Index(row, I(i)), Add(c, I(j))),
                   Mul(Index(res, I(i)), Index(y, Add(c, I(j)))))
            for j in range(inc)
        ],
        Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))),
                    lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))],
    )

    # loop = template(1, 1)
    # F.execute(loop)
    # F.emit()

    floop = template(ROW, COL)
    sloop = template(ROW, COL)
    sloop.body = AsStrided(AsStrided(sloop.body, x, c, incx), y, r, incy)

    # NOTE(review): the clean-up loop unrolls columns by 2*COL while nc is
    # computed with EvenTo(ncol, COL) -- confirm the bounds agree.
    fini = template(1, 2*COL)
    fini.init = None
    fini.body = AsStrided(AsStrided(fini.body, x, c, incx), y, r, incy)
    fini.cond = LT(r, nrow)

    F.execute(
        Set(nr, EvenTo(nrow, ROW)),
        Set(nc, EvenTo(ncol, COL)),
        If(UnitIncs(incx, incy), Block(floop), Block(sloop))
    )
    F.execute(fini)

    F.emit()


def gemv(kind):
    """Build and emit the ``gemv`` routine for element type ``kind``.

    Two static ROW x COL kernels are emitted first (unit-stride, and "_s"
    strided); the public function calls one of them on the rounded row
    count and then handles the remaining rows with a single-row loop.
    """
    name = typeify("gemv", kind)
    params = Params(
        (Int, "nrow"), (Int, "ncol"), (kind, "a"), (Ptr(kind), "m"), (Int, "incm"),
        (Ptr(kind), "x"), (Int, "incx"), (kind, "b"), (Ptr(kind), "y"), (Int, "incy")
    )
    stack = Vars((Int, "r"), (Int, "c"), (Ptr(kind), "row"), (kind, "res"))
    F = Func(pkg(name), Void, params, stack)

    # ---------------------
    # Kernel
    # The helpers below read row, res, x, y, m, incm, a and b from the
    # enclosing scope; those names are rebound for each Func further down,
    # before the helpers are called.
    def innerloop(rinc, cit, cend, cinc):
        return For(Set(cit, I(0)), LT(cit, cend), AddSet(cit, I(cinc)),
            Block(*[
                AddSet(TryIndex(res, I(i)),
                       AddElts(*(Mul(Index(TryIndex(row, I(i)), Add(cit, I(j))),
                                     Index(x, Add(cit, I(j))))
                                 for j in range(cinc))))
                for i in range(rinc)
            ])
        )

    def tryloop(rinc, cit, cend, cinc):
        # Unit-step column clean-up; only generated when columns are unrolled.
        if cinc > 1:
            loop = innerloop(rinc, cit, cend, 1)
            loop.init = None
            return loop
        return None

    def outerloop(rit, rlen, rinc, cit, cend, clen, cinc, row, res):
        # Step the iterator that was passed in (previously the closure
        # variable r was used, which only coincided with rit by accident).
        return For(Set(rit, I(0)), LT(rit, rlen), AddSet(rit, I(rinc)),
            Block(
                *[Set(TryIndex(row, I(i)), Add(m, Mul(Paren(Add(rit, I(i))), incm))) for i in range(rinc)],
                *[Set(TryIndex(res, I(i)), I(0)) for i in range(rinc)],
                innerloop(rinc, cit, cend, cinc),
                tryloop(rinc, cit, clen, cinc),
                *[Set(Index(y, Add(rit, I(i))),
                      Add(Mul(a, TryIndex(res, I(i))), Mul(b, Index(y, Add(rit, I(i))))))
                  for i in range(rinc)]
            )
        )

    kerns = []
    for func, sfx in [(IsInc, ""), (Identity, "_s")]:
        kern = Func(f"{name}{sfx}_{ROW}x{COL}kernel", Void, FilterParams(params, func),
                    stack[0:2] + toarray(ROW, *stack[2:]), static=True)
        r, c, row, res = kern.variables("r", "c", "row", "res")
        nrow, ncol, a, m, incm, x, b, y = kern.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")
        ncolr = kern.declare(Var(Int, "ncolr"))

        loop = outerloop(r, nrow, ROW, c, ncolr, ncol, COL, row, res)

        kern.execute(Round(ncolr, ncol, COL))
        kern.execute(loop)
        if "_s" in sfx:
            incx, incy = kern.variables("incx", "incy")
            StrideAllIndexedTerms(kern.stmts, x, c, incx)
            StrideAllIndexedTerms(kern.stmts, y, r, incy)

        kern.emit()
        kerns.append(kern)

    r, c, row, res = F.variables("r", "c", "row", "res")
    nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")
    incx, incy = F.variables("incx", "incy")

    F.execute(Round(r, nrow, ROW))
    F.execute(
        If(UnitIncs(incx, incy),
            Block(Call(kerns[0], [r, ncol, a, m, incm, x, b, y])),
            Block(Call(kerns[1], [r, ncol, a, m, incm, x, incx, b, y, incy])),
        )
    )
    F.params = Params((UInt32, "flag")) + F.params

    # NOTE(review): the remainder loop unrolls columns by COL up to the
    # unrounded ncol (the kernels use a rounded ncolr) -- confirm the bound.
    remainder = outerloop(r, nrow, 1, c, ncol, ncol, COL, row, res)
    remainder.init = None
    F.execute(remainder)
    # Walk the statement list, as done for the kernels above (previously the
    # Func object itself was passed).
    StrideAllIndexedTerms(F.stmts, x, c, incx)
    StrideAllIndexedTerms(F.stmts, y, r, incy)

    F.emit()


# ------------------------------------------------------------------------
# Code Generation

if __name__ == "__main__":
    # NOTE(review): the header names after "#include" appear to be missing
    # from these strings -- confirm against the intended generated output.
    emit("#include \n")
    emit("#include \n")
    emit("#include \n")
    emitln()

    emit("/*********************************************************/\n")
    emit("/* THIS CODE IS GENERATED BY GEN2.PY! DON'T EDIT BY HAND */\n")
    emit("/*********************************************************/\n")
    emitln(2)

    for kind in [Float64]:  # [Float32, Float64]:
        trsv(kind)
        # syr(kind)
        # ger(kind)
        # gemv(kind)

    flush()