diff options
author | Nicholas Noll <nbnoll@eml.cc> | 2020-05-13 17:30:19 -0700 |
---|---|---|
committer | Nicholas Noll <nbnoll@eml.cc> | 2020-05-13 17:30:19 -0700 |
commit | d982e7c2fdebf560ccce193cb98b85d4fac28a45 (patch) | |
tree | b18902eea12a2d55a24994ca0681ca1a369631aa /sys/libmath/gen1.py | |
parent | c9d4b2d7dd1d9a46571e5d2b2cf6ce10a9d9ebea (diff) |
blas 1 generation code complete
Diffstat (limited to 'sys/libmath/gen1.py')
-rwxr-xr-x | sys/libmath/gen1.py | 357 |
1 files changed, 357 insertions, 0 deletions
# diff --git a/sys/libmath/gen1.py b/sys/libmath/gen1.py
# new file mode 100755, index 0000000..b0f9ecc (@@ -0,0 +1,357 @@)
"""Generator for the BLAS level 1 routines.

Running this script emits C source (includes, a "generated code" banner and
one unrolled kernel plus one public entry point per routine and element type)
through the ``emit``/``emitln``/``flush`` helpers of the ``C`` module.
"""
from C import *

# Unroll factor handed to ``Unroll``; ``fini`` uses half of it for ``Strided``.
NUNROLL = 16


def pkg(name):
    """Return *name* qualified with the ``blas·`` package prefix."""
    return f"blas·{name}"


def typeify(string, kind):
    """Append the element-type suffix for *kind* to *string*.

    ``Float32`` yields an ``f`` suffix and ``Float64`` a ``d`` suffix.  Any
    other kind falls through and returns ``None``.
    """
    if kind == Float32:
        return f"{string}f"
    if kind == Float64:
        return f"{string}d"
    return None


def fini(func, loop, strided, calls, ret=()):
    """Finish the public entry point *func* and emit it.

    The first two entries of *calls* are executed on *func* directly.  *func*
    is then wrapped by ``Strided`` (given *loop*, an unroll factor of
    ``NUNROLL//2``, the *strided* parameters and the optional *ret* reduction
    pair), the remaining calls are executed on the wrapped function and the
    result is emitted.

    *ret* defaults to an immutable empty tuple rather than a shared list.
    """
    func.execute(*calls[:2])
    func = Strided(func, loop, NUNROLL//2, strided, *ret)
    func.execute(*calls[2:])
    func.emit()


def _count_loop(i, n):
    """Build the ``For`` node that starts *i* at 0, runs while ``i < n`` and increments *i*."""
    return For(Set(i, I(0)), LT(i, n), [Inc(i)])


def _add_reduction(acc):
    """Return the ``(variable, combiner)`` pair used to merge partial sums into *acc*."""
    return (acc, lambda ires, icur, node:
        StmtExpr(AddSet(Index(node, ires), Index(node, icur)))
    )


# ------------------------------------------------------------------------
# Blas level 1 functions

def copy(kind):
    """Generate ``copy`` for *kind*: ``y[i] = x[i]``."""
    name = typeify("copy", kind)
    F = Func(pkg(name), Void,
        Params(
            (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")
        ),
        Vars(
            (Int, "i")
        )
    )

    n, x, y, i = F.variables("len", "x", "y", "i")

    loop = _count_loop(i, n)
    loop.execute(
        Set(Index(y, i), Index(x, i))
    )

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()

    fini(F, loop, F.params[1:3], calls)


def swap(kind):
    """Generate ``swap`` for *kind*: exchange ``x[i]`` and ``y[i]`` through ``tmp``."""
    name = typeify("swap", kind)
    F = Func(pkg(name), Void,
        Params(
            (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")
        ),
        Vars(
            (Int, "i"), (kind, "tmp")
        )
    )

    n, x, y, i, tmp = F.variables("len", "x", "y", "i", "tmp")

    loop = _count_loop(i, n)
    loop.execute(
        Swap(Index(x, i), Index(y, i), tmp)
    )

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()

    fini(F, loop, F.params[1:3], calls)


def axpby(kind):
    """Generate ``axpby`` for *kind*: ``y[i] = a*x[i] + b*y[i]``.

    The previous expression was ``b*(y[i] + a*x[i])``, which scales the
    ``x`` term by ``a*b`` instead of ``a``.
    """
    name = typeify("axpby", kind)
    F = Func(pkg(name), Void,
        Params(
            (Int, "len"), (kind, "a"), (Ptr(kind), "x"), (kind, "b"), (Ptr(kind), "y")
        ),
        Vars(
            (Int, "i"),
        )
    )

    n, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b")

    loop = _count_loop(i, n)
    loop.execute(
        Set(Index(y, i),
            Add(
                Mul(a, Index(x, i)),
                Mul(b, Index(y, i)),
            )
        )
    )
    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()

    # Only the two pointer parameters (x and y) are strided.
    fini(F, loop, [F.params[2], F.params[4]], calls)


def axpy(kind):
    """Generate ``axpy`` for *kind*: ``y[i] = y[i] + a*x[i]``."""
    name = typeify("axpy", kind)
    F = Func(pkg(name), Void,
        Params(
            (Int, "len"), (kind, "a"), (Ptr(kind), "x"), (Ptr(kind), "y")
        ),
        Vars(
            (Int, "i"),
        )
    )

    n, x, y, i, a = F.variables("len", "x", "y", "i", "a")

    loop = _count_loop(i, n)
    loop.execute(
        Set(Index(y, i),
            Add(Index(y, i),
                Mul(a, Index(x, i)),
            )
        )
    )
    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()

    fini(F, loop, F.params[2:4], calls)


def _argextreme(kind, label, compare):
    """Generate ``arg<label>`` for *kind* using *compare* (``LT`` or ``GT``).

    *compare* drives both the per-element test and the merge of partial
    results, so the merge selects the same extreme as the loop body.
    Running in its own function scope also gives the combiner lambda its own
    bindings of ``x`` and ``compare``.
    """
    name = typeify(f"arg{label}", kind)
    F = Func(pkg(name), Int,
        Params(
            (Int, "len"), (Ptr(kind), "x"),
        ),
        Vars(
            (Int, "i"), (Int, "ix"), (kind, label)
        )
    )
    n, x, i, ix, best = F.variables("len", "x", "i", "ix", label)

    loop = _count_loop(i, n)
    loop.execute(
        If(compare(Index(x, i), best),
            Block(
                Set(ix, i),
                Set(best, Index(x, ix)),
            )
        )
    )

    ret = (ix, lambda ires, icur, node:
        If(compare(Index(x, Index(node, icur)), Index(x, Index(node, ires))),
            Block(Set(Index(node, ires), Index(node, icur)))
        )
    )

    kernel, calls = Unroll(name, loop, NUNROLL, *ret)
    kernel.emit()

    # Refresh the running extreme from the index found so far before the
    # remaining calls run, then return that index.
    fini(F, loop, [F.params[1]],
         calls[:2] + [Set(best, Index(x, ix))] + calls[2:] + [Return(ix)], ret)


def argminmax(kind):
    """Generate ``argmin`` and then ``argmax`` for *kind*."""
    for label, compare in (("min", LT), ("max", GT)):
        _argextreme(kind, label, compare)


def dot(kind):
    """Generate ``dot`` for *kind*: accumulate ``x[i]*y[i]`` into ``sum`` and return it."""
    nm = typeify("dot", kind)
    F = Func(pkg(nm), kind,
        Params(
            (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")
        ),
        Vars(
            (Int, "i"), (kind, "sum"),
        )
    )
    n, x, i, y, acc = F.variables("len", "x", "i", "y", "sum")

    loop = _count_loop(i, n)
    loop.execute(
        AddSet(acc, Mul(Index(x, i), Index(y, i)))
    )
    ret = _add_reduction(acc)

    kernel, calls = Unroll(nm, loop, NUNROLL, *ret)
    kernel.emit()

    fini(F, loop, F.params[1:3], calls + [Return(acc)], ret)


def norm(kind):
    """Generate ``norm`` for *kind*: the square root of the sum of ``x[i]*x[i]``.

    Raises:
        ValueError: if *kind* is neither ``Float32`` nor ``Float64``.
    """
    nm = typeify("norm", kind)
    F = Func(pkg(nm), kind,
        Params(
            (Int, "len"), (Ptr(kind), "x")
        ),
        Vars(
            (Int, "i"), (kind, "nrm"),
        )
    )
    n, x, i, nrm = F.variables("len", "x", "i", "nrm")

    loop = _count_loop(i, n)
    loop.execute(
        AddSet(nrm, Mul(Index(x, i), Index(x, i)))
    )

    ret = _add_reduction(nrm)

    kernel, calls = Unroll(nm, loop, NUNROLL, *ret)
    kernel.emit()

    # Pick the square-root routine matching the element type.
    if kind == Float64:
        sqrt = Func("math·sqrt", kind)
    elif kind == Float32:
        sqrt = Func("math·sqrtf", kind)
    else:
        raise ValueError(f"no sqrt for type {kind}")

    fini(F, loop, [F.params[1]], calls + [Return(Call(sqrt, [nrm]))], ret)


def sum(kind):
    """Generate ``sum`` for *kind*: accumulate ``x[i]`` into ``sum`` and return it."""
    name = typeify("sum", kind)
    F = Func(pkg(name), kind,
        Params(
            (Int, "len"), (Ptr(kind), "x"),
        ),
        Vars(
            (Int, "i"), (kind, "sum")
        )
    )

    n, x, i, total = F.variables("len", "x", "i", "sum")

    loop = _count_loop(i, n)
    loop.execute(
        AddSet(total, Index(x, i))
    )

    ret = _add_reduction(total)

    kernel, calls = Unroll(name, loop, NUNROLL, *ret)
    kernel.emit()

    fini(F, loop, [F.params[1]], calls + [Return(total)], ret)


def scale(kind):
    """Generate ``scale`` for *kind*: ``x[i] = a*x[i]``."""
    name = typeify("scale", kind)
    F = Func(pkg(name), Void,
        Params(
            (Int, "len"), (Ptr(kind), "x"), (kind, "a")
        ),
        Vars(
            (Int, "i"),
        )
    )

    n, a, x, i = F.variables("len", "a", "x", "i")

    loop = _count_loop(i, n)
    loop.execute(Set(Index(x, i), Mul(a, Index(x, i))))

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()

    fini(F, loop, [F.params[1]], calls)


def rot(kind):
    """Generate ``rot`` for *kind*.

    With ``tmp`` holding the old ``x[i]``:
    ``x[i] = cos*x[i] + sin*y[i]`` and ``y[i] = cos*y[i] - sin*tmp``.
    """
    name = typeify("rot", kind)
    F = Func(pkg(name), Void,
        Params(
            (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y"), (kind, "cos"), (kind, "sin")
        ),
        Vars(
            (Int, "i"), (kind, "tmp")
        )
    )

    n, x, y, i, tmp, cos, sin = F.variables("len", "x", "y", "i", "tmp", "cos", "sin")

    loop = _count_loop(i, n)
    loop.execute(
        Comma(Set(tmp, Index(x, i)),
            Comma(
                Set(Index(x, i), Add(Mul(cos, Index(x, i)), Mul(sin, Index(y, i)))),
                Set(Index(y, i), Sub(Mul(cos, Index(y, i)), Mul(sin, tmp)))
            )
        )
    )

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()

    fini(F, loop, [F.params[1], F.params[2]], calls)


def rotm(kind):
    """Generate the rotation driven by the 5-element array ``H`` for *kind*.

    With ``tmp`` holding the old ``x[i]``:
    ``x[i] = H[1]*x[i] + H[2]*y[i]`` and ``y[i] = H[3]*y[i] + H[4]*tmp``.
    """
    # NOTE(review): the routine is emitted under the name "rotg" although this
    # generator is called rotm and the TODO in the entry point lists rotg as
    # not yet implemented; left unchanged here because renaming would change
    # the emitted symbol. Confirm the intended name.
    # NOTE(review): the H[1..4] index layout does not obviously match the
    # reference BLAS rotm parameter layout; confirm against the callers.
    name = typeify("rotg", kind)
    F = Func(pkg(name), Void,
        Params(
            (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y"), (Array(kind, 5), "H")
        ),
        Vars(
            (Int, "i"), (kind, "tmp")
        )
    )

    n, x, y, i, tmp, H = F.variables("len", "x", "y", "i", "tmp", "H")

    loop = _count_loop(i, n)
    loop.execute(
        Comma(Set(tmp, Index(x, i)),
            Comma(
                Set(Index(x, i), Add(Mul(Index(H, I(1)), Index(x, i)), Mul(Index(H, I(2)), Index(y, i)))),
                Set(Index(y, i), Add(Mul(Index(H, I(3)), Index(y, i)), Mul(Index(H, I(4)), tmp)))
            )
        )
    )

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()
    fini(F, loop, F.params[1:3], calls)


if __name__ == "__main__":
    emit("#include <u.h>\n")
    emit("#include <libn.h>\n")
    emit("#include <libmath.h>\n")
    emitln()
    emit("/*********************************************************/\n")
    emit("/* THIS CODE IS GENERATED BY GEN1.PY! DON'T EDIT BY HAND */\n")
    emit("/*********************************************************/\n")
    emitln(2)

    # TODO: Implement rotg/rotmg
    # NOTE(review): swap() is defined above but is not generated here.
    for kind in [Float32, Float64]:
        argminmax(kind)
        copy(kind)
        axpy(kind)
        axpby(kind)
        dot(kind)
        sum(kind)
        norm(kind)
        scale(kind)
        rot(kind)
        rotm(kind)

    flush()