aboutsummaryrefslogtreecommitdiff
path: root/sys/libmath/gen1.py
diff options
context:
space:
mode:
Diffstat (limited to 'sys/libmath/gen1.py')
-rwxr-xr-xsys/libmath/gen1.py357
1 files changed, 357 insertions, 0 deletions
diff --git a/sys/libmath/gen1.py b/sys/libmath/gen1.py
new file mode 100755
index 0000000..b0f9ecc
--- /dev/null
+++ b/sys/libmath/gen1.py
@@ -0,0 +1,357 @@
+from C import *
+
+NUNROLL = 16
+def pkg(name):
+ return f"blas·{name}"
+
+def typeify(string, kind):
+ if (kind == Float32):
+ return f"{string}f"
+ if (kind == Float64):
+ return f"{string}d"
+
+def fini(func, loop, strided, calls, ret=[]):
+ func.execute(*calls[:2])
+ func = Strided(func, loop, NUNROLL//2, strided, *ret)
+ func.execute(*calls[2:])
+ func.emit()
+
+# ------------------------------------------------------------------------
+# Blas level 1 functions
+
+def copy(kind):
+ name = typeify("copy", kind)
+ F = Func(pkg(name), Void,
+ Params(
+ (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")
+ ),
+ Vars(
+ (Int, "i")
+ )
+ )
+
+ len, x, y, i = F.variables("len", "x", "y", "i")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ Set(Index(y, i), Index(x, i))
+ )
+
+ kernel, calls = Unroll(name, loop, NUNROLL)
+ kernel.emit()
+
+ # F.execute(*calls)
+ # F.emit()
+ fini(F, loop, F.params[1:3], calls)
+
+def swap(kind):
+ name = typeify("swap", kind)
+ F = Func(pkg(name), Void,
+ Params(
+ (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")
+ ),
+ Vars(
+ (Int, "i"), (kind, "tmp")
+ )
+ )
+
+ len, x, y, i, tmp = F.variables("len", "x", "y", "i", "tmp")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ Swap(Index(x, i), Index(y, i), tmp)
+ )
+
+ kernel, calls = Unroll(name, loop, NUNROLL)
+ kernel.emit()
+
+ # F.execute(*calls)
+ # F.emit()
+ fini(F, loop, F.params[1:3], calls)
+
+def axpby(kind):
+ name = typeify("axpby", kind)
+ F = Func(pkg(name), Void,
+ Params(
+ (Int, "len"), (kind, "a"), (Ptr(kind), "x"), (kind, "b"), (Ptr(kind), "y")
+ ),
+ Vars(
+ (Int, "i"),
+ )
+ )
+
+ len, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ Set(Index(y, i),
+ Mul(b,
+ Add(Index(y, i),
+ Mul(a, Index(x, i)),
+ )
+ )
+ )
+ )
+ kernel, calls = Unroll(name, loop, NUNROLL)
+ kernel.emit()
+
+ # F.execute(*calls)
+ # F.emit()
+ fini(F, loop, [F.params[2], F.params[4]], calls)
+
+def axpy(kind):
+ name = typeify("axpy", kind)
+ F = Func(pkg(name), Void,
+ Params(
+ (Int, "len"), (kind, "a"), (Ptr(kind), "x"), (Ptr(kind), "y")
+ ),
+ Vars(
+ (Int, "i"),
+ )
+ )
+
+ len, x, y, i, a = F.variables("len", "x", "y", "i", "a")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ Set(Index(y, i),
+ Add(Index(y, i),
+ Mul(a, Index(x, i)),
+ )
+ )
+ )
+ kernel, calls = Unroll(name, loop, NUNROLL)
+ kernel.emit()
+
+ fini(F, loop, F.params[2:4], calls)
+ # F.execute(*calls)
+ # F.emit()
+
+def argminmax(kind):
+ for operation in [("min", LT), ("max", GT)]:
+ name = typeify(f"arg{operation[0]}", kind)
+ F = Func(pkg(name), Int,
+ Params(
+ (Int, "len"), (Ptr(kind), "x"),
+ ),
+ Vars(
+ (Int, "i"), (Int, "ix"), (kind, operation[0])
+ )
+ )
+ len, x, i, ix, op = F.variables("len", "x", "i", "ix", operation[0])
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ If(operation[1](Index(x, i), op),
+ Block(
+ Set(ix, i),
+ Set(op, Index(x, ix)),
+ )
+ )
+ )
+
+ ret = (ix, lambda ires, icur, node:
+ If(GT(Index(x, Index(node, icur)), Index(x, Index(node, ires))),
+ Block(Set(Index(node, ires), Index(node, icur)))
+ )
+ )
+
+ kernel, calls = Unroll(name, loop, NUNROLL, *ret)
+ kernel.emit()
+
+ # F.execute(*calls, Return(ix))
+ # F.emit()
+ fini(F, loop, [F.params[1]], calls[:2] + [Set(op, Index(x, ix))] + calls[2:] + [Return(ix)], ret)
+
+def dot(kind):
+ nm = typeify("dot", kind)
+ F = Func(pkg(nm), kind,
+ Params(
+ (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")
+ ),
+ Vars(
+ (Int, "i"), (kind, "sum"),
+ )
+ )
+ len, x, i, y, sum = F.variables("len", "x", "i", "y", "sum")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ AddSet(sum, Mul(Index(x, i), Index(y, i)))
+ )
+ ret = (sum, lambda ires, icur, node: StmtExpr(AddSet(Index(node, ires), Index(node, icur))))
+
+ kernel, calls = Unroll(nm, loop, NUNROLL, *ret)
+ kernel.emit()
+
+ # F.execute(*calls, Return(sum))
+ fini(F, loop, F.params[1:3], calls + [Return(sum)], ret)
+
+def norm(kind):
+ nm = typeify("norm", kind)
+ F = Func(pkg(nm), kind,
+ Params(
+ (Int, "len"), (Ptr(kind), "x")
+ ),
+ Vars(
+ (Int, "i"), (kind, "nrm"),
+ )
+ )
+ len, x, i, nrm = F.variables("len", "x", "i", "nrm")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ AddSet(nrm, Mul(Index(x, i), Index(x, i)))
+ )
+
+ ret = (nrm, lambda ires, icur, node:
+ StmtExpr(AddSet(Index(node, ires), Index(node, icur)))
+ )
+
+ kernel, calls = Unroll(nm, loop, NUNROLL, *ret)
+ kernel.emit()
+
+ if kind == Float64:
+ sqrt = Func("math·sqrt", kind)
+ elif kind == Float32:
+ sqrt = Func("math·sqrtf", kind)
+ else:
+ raise ValueError(f"no sqrt for type {kind}")
+
+ # F.execute(*calls, Return(Call(sqrt, [nrm])))
+ # F.emit()
+ fini(F, loop, [F.params[1]], calls + [Return(Call(sqrt, [nrm]))], ret)
+
+def sum(kind):
+ name = typeify("sum", kind)
+ F = Func(pkg(name), kind,
+ Params(
+ (Int, "len"), (Ptr(kind), "x"),
+ ),
+ Vars(
+ (Int, "i"), (kind, "sum")
+ )
+ )
+
+ len, x, i, sum = F.variables("len", "x", "i", "sum")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ AddSet(sum, Index(x, i))
+ )
+
+ ret = (sum, lambda ires, icur, node:
+ StmtExpr(AddSet(Index(node, ires), Index(node, icur)))
+ )
+
+ kernel, calls = Unroll(name, loop, NUNROLL, *ret)
+ kernel.emit()
+
+ fini(F, loop, [F.params[1]], calls + [Return(sum)], ret)
+ # F.execute(*calls, Return(sum))
+ # F.emit()
+
+def scale(kind):
+ name = typeify("scale", kind)
+ F = Func(pkg(name), Void,
+ Params(
+ (Int, "len"), (Ptr(kind), "x"), (kind, "a")
+ ),
+ Vars(
+ (Int, "i"),
+ )
+ )
+
+ len, a, x, i = F.variables("len", "a", "x", "i")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(Set(Index(x, i), Mul(a, Index(x, i))))
+
+ kernel, calls = Unroll(name, loop, NUNROLL)
+ kernel.emit()
+
+ fini(F, loop, [F.params[1]], calls)
+ # F.execute(*calls)
+ # F.emit()
+
+def rot(kind):
+ name = typeify("rot", kind)
+ F = Func(pkg(name), Void,
+ Params(
+ (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y"), (kind, "cos"), (kind, "sin")
+ ),
+ Vars(
+ (Int, "i"), (kind, "tmp")
+ )
+ )
+
+ len, x, y, i, tmp, cos, sin = F.variables("len", "x", "y", "i", "tmp", "cos", "sin")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ Comma(Set(tmp, Index(x, i)),
+ Comma(
+ Set(Index(x, i), Add(Mul(cos, Index(x, i)), Mul(sin, Index(y, i)))),
+ Set(Index(y, i), Sub(Mul(cos, Index(y, i)), Mul(sin, tmp)))
+ )
+ )
+ )
+
+ kernel, calls = Unroll(name, loop, NUNROLL)
+ kernel.emit()
+
+ fini(F, loop, [F.params[1], F.params[2]], calls)
+ # F.execute(*calls)
+ # F.emit()
+
+def rotm(kind):
+ name = typeify("rotg", kind)
+ F = Func(pkg(name), Void,
+ Params(
+ (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y"), (Array(kind, 5), "H")
+ ),
+ Vars(
+ (Int, "i"), (kind, "tmp")
+ )
+ )
+
+ len, x, y, i, tmp, H = F.variables("len", "x", "y", "i", "tmp", "H")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ Comma(Set(tmp, Index(x, i)),
+ Comma(
+ Set(Index(x, i), Add(Mul(Index(H, I(1)), Index(x, i)), Mul(Index(H, I(2)), Index(y, i)))),
+ Set(Index(y, i), Add(Mul(Index(H, I(3)), Index(y, i)), Mul(Index(H, I(4)), tmp)))
+ )
+ )
+ )
+
+ kernel, calls = Unroll(name, loop, NUNROLL)
+ kernel.emit()
+ fini(F, loop, F.params[1:3], calls)
+
+if __name__ == "__main__":
+ emit("#include <u.h>\n")
+ emit("#include <libn.h>\n")
+ emit("#include <libmath.h>\n")
+ emitln()
+ emit("/*********************************************************/\n")
+ emit("/* THIS CODE IS GENERATED BY GEN1.PY! DON'T EDIT BY HAND */\n")
+ emit("/*********************************************************/\n")
+ emitln(2)
+
+ # TODO: Implement rotg/rotmg
+ for kind in [Float32, Float64]:
+ argminmax(kind)
+ copy(kind)
+ axpy(kind)
+ axpby(kind)
+ dot(kind)
+ sum(kind)
+ norm(kind)
+ scale(kind)
+ rot(kind)
+ rotm(kind)
+
+ flush()