diff options
author | Nicholas Noll <nbnoll@eml.cc> | 2020-05-15 10:52:11 -0700 |
---|---|---|
committer | Nicholas Noll <nbnoll@eml.cc> | 2020-05-15 10:52:11 -0700 |
commit | 66eb918a13b6607cc7bb615350a0e26f3670cd54 (patch) | |
tree | 91a9f9c3e086e778fc754aa299f223e59dad429c /sys/libmath/gen2.py | |
parent | 463ed852261da4d1dd1b859fa717a1d683306c9d (diff) |
factored out the common code of makefiles
Diffstat (limited to 'sys/libmath/gen2.py')
-rwxr-xr-x | sys/libmath/gen2.py | 122 |
1 files changed, 71 insertions, 51 deletions
diff --git a/sys/libmath/gen2.py b/sys/libmath/gen2.py index 6ce2a12..2afbe1d 100755 --- a/sys/libmath/gen2.py +++ b/sys/libmath/gen2.py @@ -104,40 +104,50 @@ def DoubleLoop(top, bot, Kernel, Preamble=[], Postamble=[]): ]) ) -def TriangularLoop(top, bot, Kernel, Preamble=[], Postamble=[], upper=True): - def Step(it, inc): - if inc == 1: - return Inc(it) - else: - return AddSet(it, I(inc)) - - def Finish(j): - if j == 0: - return For(None, LE(bot.it, top.it), Inc(bot.it), - Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)]) - ) +def TriangularLoop(top, bot, Kernel, Preamble=[], Postamble=lambda i: None, upper=True): + # --------------------------------------------------------------- + # Helper functions + + def Finish(j, upper=False): + if upper: + if j == 0: + return Block(For(None, GT(bot.it, top.it), Dec(bot.it), + Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)]) + ), Postamble(j) + ) + return Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)], Dec(bot.it), Postamble(j)) + if not upper: + if j == 0: + return Block(For(None, LT(bot.it, top.it), Inc(bot.it), + Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)]) + ), Postamble(j)) + return Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)], Inc(bot.it), Postamble(j)) + + def Step(it, inc, down=False): + if not down: + if inc == 1: + return Inc(it) + else: + return AddSet(it, I(inc)) else: - return Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)], Inc(bot.it)) + if inc == 1: + return Dec(it) + else: + return SubSet(it, I(inc)) - def Start(j, end): - if j == end: - return For(None, LT(bot.it, bot.end), Inc(bot.it), - Block(*[func for i in range(j+1) for func in Kernel(top.it, bot.it, i, 1)]) - ) - else: - return Block(*[func for i in range(j+1) for func in Kernel(top.it, bot.it, i, 1)], Inc(bot.it)) + # --------------------------------------------------------------- + # Main body if upper: - return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc), + return For(Set(top.it, Sub(top.end, I(1))), GE(top.it, I(0)), Step(top.it, top.inc, down=True), Block(*[ *[func(i) for func in Preamble for i in range(top.inc)], - Set(bot.end, Add(Paren(EvenTo(Paren(Sub(top.end, top.it)), bot.inc)), top.it)), - Set(bot.it, top.it), - *[ Start(j, top.inc-1) for j in range(top.inc) if bot.inc > 1], - For(None, LT(bot.it, bot.len), Step(bot.it, bot.inc), + Set(bot.end, Sub(Paren(EvenTo(Paren(Sub(top.it, Paren(Sub(bot.len, I(1))))), bot.inc)), Paren(Sub(bot.len, I(1))))), + For(Set(bot.it, Sub(bot.len, I(1))), GT(bot.it, bot.end), Step(bot.it, bot.inc, down=True), Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)]) ), - *[func(i) for func in Postamble for i in range(bot.inc)] + *[ Finish(j, upper) if bot.inc > 1 else Postamble(j) for j in range(top.inc) ], + # *[func(i) for func in Postamble for i in range(bot.inc)] ]) ) else: @@ -145,27 +155,33 @@ def TriangularLoop(top, bot, Kernel, Preamble=[], Postamble=[], upper=True): Block(*[ *[func(i) for func in Preamble for i in range(top.inc)], Set(bot.end, EvenTo(top.it, bot.inc)), - For(Set(bot.it, I(0)), LE(bot.it, bot.end), Step(bot.it, bot.inc), + For(Set(bot.it, I(0)), LT(bot.it, bot.end), Step(bot.it, bot.inc), Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)]) ), - *[ Finish(j) for j in range(top.inc) if bot.inc > 1], - *[func(i) for func in Postamble for i in range(bot.inc)] + *[ Finish(j) if bot.inc > 1 else Postamble(j) for j in range(top.inc) ], + # *[func(i) for func in Postamble for i in range(bot.inc)] ]) ) +def Shift(x, i, uplo): + if uplo == "upper": + return Sub(x, I(i)) + if uplo == "lower": + return Add(x, I(i)) + raise ValueError("unrecognized value") def ToKernel(name, loop): vars = VarsUsed(StmtExpr(loop.init)) | VarsUsed(StmtExpr(loop.cond)) | \ VarsUsed(StmtExpr(loop.step)) | VarsUsed(loop.body) -# def ExpandAdd(i: int, c: Emitter, inc: int): -# offset = Add(c, I(0)) -# root = Mul(Index(Index(row, I(i)), offset), Index(x, offset)) -# for n in range(1, inc): -# offset = Add(c, I(n)) -# root = Add(root, Mul(Index(Index(row, I(i)), offset), Index(x, offset))) -# return root +def ExpandAdd(row, x, i: int, c: Emitter, inc: int, uplo): + offset = Add(c, I(0)) + root = Mul(Index(Index(row, I(i)), offset), Index(x, offset)) + for n in range(1, inc): + offset = Shift(c, I(n), uplo) + root = Add(root, Mul(Index(Index(row, I(i)), offset), Index(x, offset))) + return root # ------------------------------------------------------------------------ # Blas level 2 functions @@ -181,20 +197,21 @@ def trsv(kind): ) ) - r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res") - flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x") - incx = F.variables("incx") + r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res") + flag, _len, m, incm, x = F.variables("flag", "len", "m", "incm", "x") + incx = F.variables("incx") rows, cols = lambda inc_r: Iter(r, nr, _len, inc_r), lambda inc_c: Iter(c, nc, _len, inc_c) - template = lambda inc_r, inc_c: TriangularLoop(rows(inc_r), cols(inc_c), - Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Add(c, I(j))), Mul(Index(res, I(i)), Index(x, Add(c, I(j))))) for j in range(inc)], - Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))), - lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))], - upper = True + template = lambda inc_r, inc_c, upper: TriangularLoop(rows(inc_r), cols(inc_c), + Kernel = lambda r, c, i, inc: [AddSet(Index(res, I(i)), ExpandAdd(row, x, i, c, inc, upper))], + Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Shift(r, i, upper)), incm))), + lambda i: Set(Index(res, I(i)), I(0))], + Postamble = lambda i: Set(Index(x, Shift(c, I(0), upper)), Div(Index(res, I(i)), Index(Index(row, I(i)), Shift(c, I(0), upper)))), + upper = upper == "upper" ) - loop = template(1, 1) + loop = template(1, 1, "upper") loop.emit() def syr(kind): @@ -216,16 +233,16 @@ def syr(kind): rows, cols = lambda inc_r: Iter(r, nr, _len, inc_r), lambda inc_c: Iter(c, nc, _len, inc_c) template = lambda inc_r, inc_c, upper: TriangularLoop(rows(inc_r), cols(inc_c), - Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Add(c, I(j))), Mul(Index(res, I(i)), Index(x, Add(c, I(j))))) for j in range(inc)], - Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))), - lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))], + Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Shift(c, j, upper)), Mul(Index(res, I(i)), Index(x, Shift(c, j, upper)))) for j in range(inc)], + Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Shift(r, i, upper)), incm))), + lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Shift(r, i, upper))))], upper = upper == "upper" ) blocks = [] for layout in ["lower", "upper"]: - floop = template(1, 1, layout) - sloop = template(1, 1, layout) + floop = template(2, 2, layout) + sloop = template(2, 2, layout) sloop.body = AsStrided(sloop.body, x, c, incx) fini = template(1, 1, layout) @@ -240,7 +257,10 @@ def syr(kind): Return(), ) ) - F.execute(If(flag, blocks[0], blocks[1])) + F.execute( + Set(nr, EvenTo(_len, ROW)), + If(flag, blocks[0], blocks[1]) + ) F.emit() def ger(kind): |