about | summary | refs | log | tree | commit | diff
path: root/sys/libmath/gen2.py
diff options
context:
space:
mode:
Diffstat (limited to 'sys/libmath/gen2.py')
-rwxr-xr-x sys/libmath/gen2.py 410
1 files changed, 0 insertions, 410 deletions
diff --git a/sys/libmath/gen2.py b/sys/libmath/gen2.py
deleted file mode 100755
index 2afbe1d..0000000
--- a/sys/libmath/gen2.py
+++ /dev/null
@@ -1,410 +0,0 @@
-from C import *
-import copy
-
# Unroll/tile sizes shared by the level-2 generators: ROW is the length of the
# per-row pointer array ("row"), COL the length of the accumulator array ("res").
ROW, COL = 4, 4
-
def pkg(name):
    """Return `name` qualified with the `blas` package prefix.

    The separator is the middle dot (U+00B7). It is written as an escape so
    that a lossy re-encoding of this file cannot corrupt the emitted symbol
    again (the previous literal had decayed into a two-character mojibake
    sequence).
    NOTE(review): presumably this mirrors the `pkg·name` identifier convention
    of the C library the output is compiled into -- confirm.
    """
    return f"blas\u00b7{name}"
-
def typeify(string, kind):
    """Append the precision suffix for `kind` to `string`.

    Float32 yields an "f" suffix and Float64 a "d" suffix.

    Raises:
        ValueError: if `kind` is neither Float32 nor Float64. Previously the
            function fell off the end and returned None, which would then be
            formatted into the generated symbol name by the caller.
    """
    if kind == Float32:
        return f"{string}f"
    if kind == Float64:
        return f"{string}d"
    raise ValueError(f"unsupported kind: {kind!r}")
-
-# ------------------------------------------------------------------------
-# Helpers (abandoning the automatic unroll from level 1)
-
def toarray(len: int, *args):
    """Re-declare each of `args` as an array of `len` elements.

    Every argument must expose `.type` and `.name`; the result is a list of
    `Param` objects with the same names and element types wrapped in `Array`.
    """
    arrays = []
    for arg in args:
        arrays.append(Param(Array(arg.type, len), arg.name))
    return arrays
-
def TryIndex(x, i):
    """Return `Index(x, i)` when `x` is array-typed, otherwise `x` unchanged.

    Lets the same generator code work with either a scalar variable or an
    array of them.
    """
    return Index(x, i) if IsArrayType(x.var.type) else x
-
def AddElts(root, *vars):
    """Fold `root` and every following term into a left-nested chain of `Add` nodes."""
    total = root
    for term in vars:
        total = Add(total, term)
    return total
-
def Round(store, number, by):
    """Build the statement `store = And(number, Negate(by - 1))`.

    NOTE(review): this reads as "round `number` down to a multiple of `by`",
    which only holds if `Negate` emits a bitwise complement and `by` is a
    power of two -- confirm against the C module.
    """
    mask = Negate(I(by - 1))
    return Set(store, And(number, mask))
-
def UnitIncs(root, *incs):
    """Build the condition that `root` and every one of `incs` equals 1.

    The individual `EQ(..., I(1))` tests are chained left to right with
    `AndAnd`.
    """
    cond = EQ(root, I(1))
    for stride in incs:
        cond = AndAnd(cond, EQ(stride, I(1)))
    return cond
-
def IsInc(p):
    """Return True when `p` is NOT one of the stride parameters.

    Despite the name, this predicate rejects the parameters called "incx" and
    "incy" and accepts everything else; it is used as a filter to drop the
    strides from a parameter list.
    """
    return p.name not in ("incx", "incy")
-
def Identity(p):
    """Accept every parameter; the keep-everything counterpart of `IsInc`."""
    return True
-
def FilterParams(params, func):
    """Return a new list holding the members of `params` for which `func` is truthy."""
    return list(filter(func, params))
-
def StrideAllIndexedTerms(stmts, var, itor, inc):
    """Multiply, in place, matching index expressions in `stmts` by `inc`.

    A node matches when it is an `Index` into `var` whose subscript is a
    `BinaryOp` with `itor` as its left operand. Every match has its subscript
    rewritten to `Mul(Paren(<old subscript>), inc)`. Nothing is returned; the
    statements are mutated.
    """
    def is_hit(node):
        if not isinstance(node, Index):
            return False
        if not (isinstance(node.i, BinaryOp) and node.x == var):
            return False
        return node.i.l == itor

    hits = []

    def collect(node):
        return Filter(node, is_hit, hits)

    # Gather every match first, then rewrite, so the traversal never sees a
    # subscript that has already been strided.
    for stmt in stmts:
        Visit(stmt, collect)

    for hit in hits:
        hit.i = Mul(Paren(hit.i), inc)
-
def AsStrided(stmts, var, itor, inc):
    """Return a copy of `stmts` in which indexing into `var` is scaled by `inc`.

    Every `Index` node over `var` whose subscript is a `BinaryOp` is replaced
    by `Index(var, Mul(Paren(<subscript>), inc))`; all other nodes are shallow
    copied. Unlike `StrideAllIndexedTerms`, the input is not mutated.

    Args:
        stmts: a `Block` or a plain list of statements.
        var: the variable whose subscripts are strided.
        itor: accepted for signature parity with `StrideAllIndexedTerms`; it
            is not consulted when matching.
        inc: the stride expression to multiply by.

    Returns:
        A `Block` when given a `Block`, a list when given a list.

    Raises:
        TypeError: if `stmts` is neither a `Block` nor a list.
    """
    def increment(x):
        if isinstance(x, Index):
            if isinstance(x.i, BinaryOp) and x.x == var:
                return Index(x.x, Mul(Paren(x.i), inc))

        return copy.copy(x)

    def rebuild(stmt):
        # Both container kinds must hand `Transform` the callable itself. The
        # list branch used to pass `increment(node)`, i.e. an already
        # transformed node, where the Block branch passed the function.
        return Make(stmt, lambda node: Transform(node, increment))

    if isinstance(stmts, Block):
        return Block(*[rebuild(stmt) for stmt in stmts.stmts])
    if isinstance(stmts, list):
        return [rebuild(stmt) for stmt in stmts]
    raise TypeError("unrecognized stmts type")
-
class Iter:
    """Plain record describing one loop dimension.

    Attributes are stored exactly as given:
      it  -- the loop iterator
      end -- the bound compared against in the stepped loop
      len -- the bound compared against in the clean-up loop
      inc -- the step taken per iteration
    """

    def __init__(self, it, end, len, inc):
        self.it, self.end = it, end
        self.len, self.inc = len, inc
-
def DoubleLoop(top, bot, Kernel, Preamble=None, Postamble=None):
    """Build a two-level `For` nest over a full (rectangular) iteration space.

    The outer loop runs `top.it` from 0 while it is below `top.end`, stepping
    by `top.inc`. Its body holds, in order:
      1. every Preamble callable applied to each i in range(top.inc),
      2. the inner loop over `bot.it` from 0 to `bot.end` in steps of
         `bot.inc`, whose body is `Kernel(top.it, bot.it, i, bot.inc)` for
         each i in range(top.inc),
      3. a clean-up loop that continues `bot.it` one step at a time up to
         `bot.len`, using `Kernel(..., i, 1)`,
      4. every Postamble callable applied to each i in range(bot.inc).

    Args:
        top, bot: `Iter`-like records for the outer and inner dimension.
        Kernel: callable `(top_it, bot_it, i, width) -> iterable of statements`.
        Preamble, Postamble: iterables of callables `i -> statement`;
            None (the default) means none.

    Returns:
        The outer `For` node.
    """
    # None stands in for "no callbacks" so that no list object is shared
    # between calls through the default arguments.
    preamble = [] if Preamble is None else Preamble
    postamble = [] if Postamble is None else Postamble

    def Step(it, inc):
        return Inc(it) if inc == 1 else AddSet(it, I(inc))

    def Unrolled(width):
        return [stmt for i in range(top.inc) for stmt in Kernel(top.it, bot.it, i, width)]

    # NOTE(review): the postamble is repeated bot.inc times while the preamble
    # is repeated top.inc times -- confirm this asymmetry is intended.
    return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc),
        Block(
            *[func(i) for func in preamble for i in range(top.inc)],
            For(Set(bot.it, I(0)), LT(bot.it, bot.end), Step(bot.it, bot.inc),
                Block(*Unrolled(bot.inc))
            ),
            For(None, LT(bot.it, bot.len), Inc(bot.it),
                Block(*Unrolled(1))
            ),
            *[func(i) for func in postamble for i in range(bot.inc)],
        )
    )
-
def TriangularLoop(top, bot, Kernel, Preamble=None, Postamble=lambda i: None, upper=True):
    """Build a two-level `For` nest whose inner bound depends on the outer iterator.

    With `upper` true the outer iterator counts down from `top.end - 1` to 0
    and the inner iterator counts down from `bot.len - 1`; otherwise both
    count up from 0. In either case the outer body is:
      1. every Preamble callable applied to each i in range(top.inc),
      2. an assignment of the stepped inner bound to `bot.end` (built with
         `EvenTo` from `top.it` and `bot.inc`),
      3. the stepped inner loop, running `Kernel(top.it, bot.it, i, bot.inc)`
         for each i in range(top.inc),
      4. for each j in range(top.inc): a `Finish(j)` block when
         `bot.inc > 1`, otherwise just `Postamble(j)`.

    Args:
        top, bot: `Iter`-like records for the outer and inner dimension.
        Kernel: callable `(top_it, bot_it, i, width) -> iterable of statements`.
        Preamble: iterable of callables `i -> statement`; None means none.
        Postamble: a single callable `j -> statement` (note: not a list,
            unlike `DoubleLoop`).
        upper: selects the downward (True) or upward (False) traversal.

    Returns:
        The outer `For` node.
    """
    # None stands in for "no callbacks" so that no list object is shared
    # between calls through the default argument.
    preamble = [] if Preamble is None else Preamble

    # ---------------------------------------------------------------
    # Helper functions

    def Unrolled(first, width):
        return [stmt for i in range(first, top.inc) for stmt in Kernel(top.it, bot.it, i, width)]

    def Finish(j, upper=False):
        # The two directions differ only in the comparison and in whether the
        # inner iterator is decremented or incremented.
        compare, advance = (GT, Dec) if upper else (LT, Inc)
        if j == 0:
            # Row 0 walks the inner iterator one step at a time until it
            # meets the outer iterator.
            return Block(
                For(None, compare(bot.it, top.it), advance(bot.it),
                    Block(*Unrolled(j, 1))
                ),
                Postamble(j),
            )
        # Later rows emit one more unit-width kernel step for rows j and up,
        # then advance the inner iterator once.
        return Block(*Unrolled(j, 1), advance(bot.it), Postamble(j))

    def Step(it, inc, down=False):
        if inc == 1:
            return Dec(it) if down else Inc(it)
        return SubSet(it, I(inc)) if down else AddSet(it, I(inc))

    # ---------------------------------------------------------------
    # Main body

    if upper:
        return For(Set(top.it, Sub(top.end, I(1))), GE(top.it, I(0)), Step(top.it, top.inc, down=True),
            Block(
                *[func(i) for func in preamble for i in range(top.inc)],
                Set(bot.end, Sub(Paren(EvenTo(Paren(Sub(top.it, Paren(Sub(bot.len, I(1))))), bot.inc)), Paren(Sub(bot.len, I(1))))),
                For(Set(bot.it, Sub(bot.len, I(1))), GT(bot.it, bot.end), Step(bot.it, bot.inc, down=True),
                    Block(*Unrolled(0, bot.inc))
                ),
                *[Finish(j, upper) if bot.inc > 1 else Postamble(j) for j in range(top.inc)],
            )
        )

    return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc),
        Block(
            *[func(i) for func in preamble for i in range(top.inc)],
            Set(bot.end, EvenTo(top.it, bot.inc)),
            For(Set(bot.it, I(0)), LT(bot.it, bot.end), Step(bot.it, bot.inc),
                Block(*Unrolled(0, bot.inc))
            ),
            *[Finish(j) if bot.inc > 1 else Postamble(j) for j in range(top.inc)],
        )
    )
-
def Shift(x, i, uplo):
    """Offset `x` by the literal `i` in the direction implied by `uplo`.

    "upper" yields `Sub(x, I(i))`, "lower" yields `Add(x, I(i))`. `i` is
    wrapped in `I()` here.

    Raises:
        ValueError: for any other `uplo`.
    """
    if uplo not in ("upper", "lower"):
        raise ValueError("unrecognized value")
    step = I(i)
    return Sub(x, step) if uplo == "upper" else Add(x, step)
-
-
def ToKernel(name, loop):
    """Collect the variables referenced by `loop` (unfinished).

    Unions `VarsUsed` over the loop's init, cond and step (each wrapped in
    `StmtExpr`) and over its body. The result is currently discarded and the
    function returns None; `name` is not used yet.
    """
    used = VarsUsed(StmtExpr(loop.init))
    used = used | VarsUsed(StmtExpr(loop.cond))
    used = used | VarsUsed(StmtExpr(loop.step))
    used = used | VarsUsed(loop.body)
-
def ExpandAdd(row, x, i: int, c: Emitter, inc: int, uplo):
    """Build the sum of `inc` products `row[i][c +/- n] * x[c +/- n]`.

    The first term uses the offset `Add(c, I(0))`; each further term n in
    1..inc-1 uses `Shift(c, n, uplo)`, so the direction follows `uplo`
    ("upper" subtracts, "lower" adds).

    Returns:
        A left-nested chain of `Add` nodes over the products.

    Raises:
        ValueError: propagated from `Shift` for an unrecognized `uplo` when
            `inc > 1`.
    """
    def term(offset):
        return Mul(Index(Index(row, I(i)), offset), Index(x, offset))

    root = term(Add(c, I(0)))
    for n in range(1, inc):
        # Shift() wraps its offset in I() itself, so hand it the plain int;
        # passing I(n) here produced a doubly wrapped literal, I(I(n)).
        root = Add(root, term(Shift(c, n, uplo)))
    return root
-
-# ------------------------------------------------------------------------
-# Blas level 2 functions
-
def trsv(kind):
    """Generate and emit the triangular loop of the `trsv` routine for `kind`.

    Declares the function `blas·trsv{f,d}` with its parameters and locals,
    builds a 1x1 "upper" `TriangularLoop` and emits that loop.
    NOTE(review): only the loop is emitted; unlike `syr`/`ger`, the function
    object itself is never executed or emitted -- confirm this is work in
    progress.
    """
    F = Func(pkg(typeify("trsv", kind)), Void,
        Params(
            (UInt32, "flag"), (Int, "len"), (Ptr(kind), "m"), (Int, "incm"), (Ptr(kind), "x"), (Int, "incx")
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
        )
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    flag, _len, m, incm, x = F.variables("flag", "len", "m", "incm", "x")
    incx = F.variables("incx")

    def rows(inc_r):
        return Iter(r, nr, _len, inc_r)

    def cols(inc_c):
        return Iter(c, nc, _len, inc_c)

    def template(inc_r, inc_c, uplo):
        def accumulate(r_it, c_it, i, inc):
            # res[i] += sum of row[i][...] * x[...] over `inc` columns
            return [AddSet(Index(res, I(i)), ExpandAdd(row, x, i, c_it, inc, uplo))]

        def point_row(i):
            # row[i] = m + (r -/+ i) * incm
            return Set(Index(row, I(i)), Add(m, Mul(Paren(Shift(r, i, uplo)), incm)))

        def clear_result(i):
            return Set(Index(res, I(i)), I(0))

        def solve(i):
            # x[c] = res[i] / row[i][c]
            return Set(
                Index(x, Shift(c, I(0), uplo)),
                Div(Index(res, I(i)), Index(Index(row, I(i)), Shift(c, I(0), uplo))),
            )

        return TriangularLoop(rows(inc_r), cols(inc_c),
            Kernel=accumulate,
            Preamble=[point_row, clear_result],
            Postamble=solve,
            upper=(uplo == "upper"),
        )

    loop = template(1, 1, "upper")
    loop.emit()
-
def syr(kind):
    """Generate and emit the `syr` routine for `kind`.

    For each layout ("lower", then "upper") three triangular loops are built:
    a 2x2 loop used when `incx` is 1, a 2x2 loop with `x` subscripts strided
    by `incx`, and a 1x1 strided clean-up loop that continues from the
    current `r` up to `len`. The function body sets `nr` and then selects the
    "lower" block when `flag` is true, the "upper" block otherwise.
    """
    F = Func(pkg(typeify("syr", kind)), Void,
        Params(
            (UInt32, "flag"), (Int, "len"), (kind, "a"),
            (Ptr(kind), "x"), (Int, "incx"), (Ptr(kind), "m"), (Int, "incm"),
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
        )
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x")
    incx = F.variables("incx")

    def rows(inc_r):
        return Iter(r, nr, _len, inc_r)

    def cols(inc_c):
        return Iter(c, nc, _len, inc_c)

    def template(inc_r, inc_c, uplo):
        def update(r_it, c_it, i, inc):
            # row[i][c -/+ j] += res[i] * x[c -/+ j] for each unrolled column j
            return [
                AddSet(
                    Index(Index(row, I(i)), Shift(c_it, j, uplo)),
                    Mul(Index(res, I(i)), Index(x, Shift(c_it, j, uplo))),
                )
                for j in range(inc)
            ]

        def point_row(i):
            # row[i] = m + (r -/+ i) * incm
            return Set(Index(row, I(i)), Add(m, Mul(Paren(Shift(r, i, uplo)), incm)))

        def scale_x(i):
            # res[i] = a * x[r -/+ i]
            return Set(Index(res, I(i)), Mul(a, Index(x, Shift(r, i, uplo))))

        return TriangularLoop(rows(inc_r), cols(inc_c),
            Kernel=update,
            Preamble=[point_row, scale_x],
            upper=(uplo == "upper"),
        )

    def variant(layout):
        floop = template(2, 2, layout)
        sloop = template(2, 2, layout)
        sloop.body = AsStrided(sloop.body, x, c, incx)

        # Clean-up loop: no initializer, so it resumes where the unrolled
        # loop stopped and runs until r reaches len.
        fini = template(1, 1, layout)
        fini.init = None
        fini.body = AsStrided(fini.body, x, c, incx)
        fini.cond = LT(r, _len)

        return Block(
            If(UnitIncs(incx), Block(floop), Block(sloop)),
            fini,
            Return(),
        )

    lower = variant("lower")
    upper = variant("upper")

    F.execute(
        Set(nr, EvenTo(_len, ROW)),
        If(flag, lower, upper)
    )
    F.emit()
-
def ger(kind):
    """Generate and emit the `ger` routine for `kind`.

    Builds a ROW x COL `DoubleLoop` for the unit-increment case, the same
    loop with `x`/`y` subscripts strided for the general case, and a strided
    clean-up loop that resumes at the current `r` and runs until `nrow`.
    """
    F = Func(pkg(typeify("ger", kind)), Void,
        Params(
            (Int, "nrow"), (Int, "ncol"), (kind, "a"),
            (Ptr(kind), "x"), (Int, "incx"), (Ptr(kind), "y"), (Int, "incy"), (Ptr(kind), "m"), (Int, "incm"),
        ),
        Vars(
            (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
        )
    )

    r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
    nrow, ncol, a, m, incm, x, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "y")
    incx, incy = F.variables("incx", "incy")

    def rows(incr):
        return Iter(r, nr, nrow, incr)

    def cols(incc):
        return Iter(c, nc, ncol, incc)

    def update(r_it, c_it, i, inc):
        # row[i][c + j] += res[i] * y[c + j] for each unrolled column j
        return [
            AddSet(
                Index(Index(row, I(i)), Add(c_it, I(j))),
                Mul(Index(res, I(i)), Index(y, Add(c_it, I(j)))),
            )
            for j in range(inc)
        ]

    def point_row(i):
        # row[i] = m + (r + i) * incm
        return Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm)))

    def scale_x(i):
        # res[i] = a * x[r + i]
        return Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))

    def template(incr, incc):
        return DoubleLoop(rows(incr), cols(incc),
            Kernel=update,
            Preamble=[point_row, scale_x],
        )

    def strided(body):
        return AsStrided(AsStrided(body, x, c, incx), y, r, incy)

    floop = template(ROW, COL)
    sloop = template(ROW, COL)
    sloop.body = strided(sloop.body)

    # Clean-up loop: one row at a time, without an initializer so it resumes
    # at the current r.
    fini = template(1, 2*COL)
    fini.init = None
    fini.body = strided(fini.body)
    fini.cond = LT(r, nrow)

    F.execute(
        Set(nr, EvenTo(nrow, ROW)),
        Set(nc, EvenTo(ncol, COL)),
        If(UnitIncs(incx, incy), Block(floop), Block(sloop))
    )
    F.execute(fini)
    F.emit()
-
def gemv(kind):
    """Generate and emit the `gemv` routine for `kind` and its two static kernels.

    Two ROW x COL kernels are emitted first: one whose parameter list omits
    `incx`/`incy` (suffix ""), and one that keeps them (suffix "_s") and has
    its `x`/`y` subscripts multiplied by the strides. The public function
    rounds the row count into `r`, calls the kernel matching the increments,
    then handles the remaining rows one at a time with strided subscripts.
    """
    name = typeify("gemv", kind)
    params = Params(
        (Int, "nrow"), (Int, "ncol"), (kind, "a"), (Ptr(kind), "m"), (Int, "incm"),
        (Ptr(kind), "x"), (Int, "incx"), (kind, "b"), (Ptr(kind), "y"), (Int, "incy")
    )
    stack = Vars((Int, "r"), (Int, "c"), (Ptr(kind), "row"), (kind, "res"))
    F = Func(pkg(name), Void, params, stack)

    # ---------------------
    # Kernel
    #
    # The helpers below read row, res, x, y, m, incm, a and b from the
    # enclosing scope at call time, so those names must be bound (via
    # `.variables(...)`) before any helper is invoked.

    def innerloop(rinc, cit, cend, cinc):
        # res[i] += sum over j < cinc of row[i][cit + j] * x[cit + j]
        return For(Set(cit, I(0)), LT(cit, cend), AddSet(cit, I(cinc)),
            Block(*[
                AddSet(
                    TryIndex(res, I(i)),
                    AddElts(*(
                        Mul(
                            Index(TryIndex(row, I(i)), Add(cit, I(j))),
                            Index(x, Add(cit, I(j))),
                        )
                        for j in range(cinc)
                    )),
                )
                for i in range(rinc)
            ])
        )

    def tryloop(rinc, cit, cend, cinc):
        # A unit-step continuation of the inner loop is only needed when the
        # inner loop was unrolled; otherwise None is returned.
        if cinc > 1:
            loop = innerloop(rinc, cit, cend, 1)
            loop.init = None
            return loop
        return None

    def outerloop(rit, rlen, rinc, cit, cend, clen, cinc, row, res):
        # Step the iterator that was passed in. The original stepped the
        # closure variable `r`, which only worked because every caller
        # happens to pass `r` as `rit`.
        return For(Set(rit, I(0)), LT(rit, rlen), AddSet(rit, I(rinc)),
            Block(
                *[Set(TryIndex(row, I(i)), Add(m, Mul(Paren(Add(rit, I(i))), incm))) for i in range(rinc)],
                *[Set(TryIndex(res, I(i)), I(0)) for i in range(rinc)],
                innerloop(rinc, cit, cend, cinc),
                tryloop(rinc, cit, clen, cinc),
                *[Set(Index(y, Add(rit, I(i))), Add(Mul(a, TryIndex(res, I(i))), Mul(b, Index(y, Add(rit, I(i)))))) for i in range(rinc)]
            )
        )

    kerns = []
    for func, sfx in [(IsInc, ""), (Identity, "_s")]:
        kern = Func(f"{name}{sfx}_{ROW}x{COL}kernel", Void, FilterParams(params, func), stack[0:2] + toarray(ROW, *stack[2:]), static=True)
        r, c, row, res = kern.variables("r", "c", "row", "res")
        nrow, ncol, a, m, incm, x, b, y = kern.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")

        ncolr = kern.declare(Var(Int, "ncolr"))
        loop = outerloop(r, nrow, ROW, c, ncolr, ncol, COL, row, res)

        kern.execute(Round(ncolr, ncol, COL))
        kern.execute(loop)
        if "_s" in sfx:
            incx, incy = kern.variables("incx", "incy")
            StrideAllIndexedTerms(kern.stmts, x, c, incx)
            StrideAllIndexedTerms(kern.stmts, y, r, incy)

        kern.emit()

        kerns.append(kern)

    r, c, row, res = F.variables("r", "c", "row", "res")
    nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")
    incx, incy = F.variables("incx", "incy")
    F.execute(Round(r, nrow, ROW))
    F.execute(
        If(UnitIncs(incx, incy),
            Block(Call(kerns[0], [r, ncol, a, m, incm, x, b, y])),
            Block(Call(kerns[1], [r, ncol, a, m, incm, x, incx, b, y, incy])),
        )
    )

    F.params = Params((UInt32, "flag")) + F.params

    # Remaining rows, one at a time; no initializer, so r continues from the
    # rounded row count set above.
    # NOTE(review): the stepped inner loop here is bounded by ncol itself
    # rather than a COL-rounded count as in the kernels -- confirm.
    remainder = outerloop(r, nrow, 1, c, ncol, ncol, COL, row, res)
    remainder.init = None
    F.execute(remainder)
    # StrideAllIndexedTerms iterates over statements, so pass F.stmts (as the
    # kernel path does with kern.stmts) rather than the Func object itself.
    StrideAllIndexedTerms(F.stmts, x, c, incx)
    StrideAllIndexedTerms(F.stmts, y, r, incy)

    F.emit()
-
-# ------------------------------------------------------------------------
-# Code Generation
-
if __name__ == "__main__":
    # File prologue: the C includes, a blank line, then the generated-code
    # banner followed by two blank lines.
    for include in (
        "#include <u.h>\n",
        "#include <libn.h>\n",
        "#include <libmath.h>\n",
    ):
        emit(include)
    emitln()
    for banner in (
        "/*********************************************************/\n",
        "/* THIS CODE IS GENERATED BY GEN2.PY! DON'T EDIT BY HAND */\n",
        "/*********************************************************/\n",
    ):
        emit(banner)
    emitln(2)

    # Only the Float64 trsv generator is enabled at the moment; Float32 and
    # the syr, ger and gemv generators are switched off.
    for kind in [Float64]:
        trsv(kind)

    flush()