diff options
Diffstat (limited to 'sys/libmath/gen1.py')
-rwxr-xr-x | sys/libmath/gen1.py | 360 |
1 files changed, 0 insertions, 360 deletions
diff --git a/sys/libmath/gen1.py b/sys/libmath/gen1.py deleted file mode 100755 index 936bc50..0000000 --- a/sys/libmath/gen1.py +++ /dev/null @@ -1,360 +0,0 @@ -#!/bin/python - -from C import * - -NUNROLL = 16 -def pkg(name): - return f"blas·{name}" - -def typeify(string, kind): - if (kind == Float32): - return f"{string}f" - if (kind == Float64): - return f"{string}d" - -def fini(func, loop, strided, calls, ret=[]): - func.execute(*calls[:2]) - func, scall = Strided(func, loop, NUNROLL//2, strided, *ret) - calls[2] = scall[0] - func.execute(*calls[2:]) - func.emit() - -# ------------------------------------------------------------------------ -# Blas level 1 functions - -def copy(kind): - name = typeify("copy", kind) - F = Func(pkg(name), Void, - Params( - (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y") - ), - Vars( - (Int, "i") - ) - ) - - len, x, y, i = F.variables("len", "x", "y", "i") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - Set(Index(y, i), Index(x, i)) - ) - - kernel, calls = Unroll(name, loop, NUNROLL) - kernel.emit() - - # F.execute(*calls) - # F.emit() - fini(F, loop, F.params[1:3], calls) - -def swap(kind): - name = typeify("swap", kind) - F = Func(pkg(name), Void, - Params( - (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y") - ), - Vars( - (Int, "i"), (kind, "tmp") - ) - ) - - len, x, y, i, tmp = F.variables("len", "x", "y", "i", "tmp") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - Swap(Index(x, i), Index(y, i), tmp) - ) - - kernel, calls = Unroll(name, loop, NUNROLL) - kernel.emit() - - # F.execute(*calls) - # F.emit() - fini(F, loop, F.params[1:3], calls) - -def axpby(kind): - name = typeify("axpby", kind) - F = Func(pkg(name), Void, - Params( - (Int, "len"), (kind, "a"), (Ptr(kind), "x"), (kind, "b"), (Ptr(kind), "y") - ), - Vars( - (Int, "i"), - ) - ) - - len, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - Set(Index(y, i), - Mul(b, - Add(Index(y, i), - Mul(a, Index(x, i)), - ) - ) - ) - ) - kernel, calls = Unroll(name, loop, NUNROLL) - kernel.emit() - - # F.execute(*calls) - # F.emit() - fini(F, loop, [F.params[2], F.params[4]], calls) - -def axpy(kind): - name = typeify("axpy", kind) - F = Func(pkg(name), Void, - Params( - (Int, "len"), (kind, "a"), (Ptr(kind), "x"), (Ptr(kind), "y") - ), - Vars( - (Int, "i"), - ) - ) - - len, x, y, i, a = F.variables("len", "x", "y", "i", "a") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - Set(Index(y, i), - Add(Index(y, i), - Mul(a, Index(x, i)), - ) - ) - ) - kernel, calls = Unroll(name, loop, NUNROLL) - kernel.emit() - - fini(F, loop, F.params[2:4], calls) - # F.execute(*calls) - # F.emit() - -def argminmax(kind): - for operation in [("min", LT), ("max", GT)]: - name = typeify(f"arg{operation[0]}", kind) - F = Func(pkg(name), Int, - Params( - (Int, "len"), (Ptr(kind), "x"), - ), - Vars( - (Int, "i"), (Int, "ix"), (kind, operation[0]) - ) - ) - len, x, i, ix, op = F.variables("len", "x", "i", "ix", operation[0]) - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - If(operation[1](Index(x, i), op), - Block( - Set(ix, i), - Set(op, Index(x, ix)), - ) - ) - ) - - ret = (ix, lambda ires, icur, node: - If(GT(Index(x, Index(node, icur)), Index(x, Index(node, ires))), - Block(Set(Index(node, ires), Index(node, icur))) - ) - ) - - kernel, calls = Unroll(name, loop, NUNROLL, *ret) - kernel.emit() - - # F.execute(*calls, Return(ix)) - # F.emit() - fini(F, loop, [F.params[1]], calls[:2] + [Set(op, Index(x, ix))] + calls[2:] + [Return(ix)], ret) - -def dot(kind): - nm = typeify("dot", kind) - F = Func(pkg(nm), kind, - Params( - (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y") - ), - Vars( - (Int, "i"), (kind, "sum"), - ) - ) - len, x, i, y, sum = F.variables("len", "x", "i", "y", "sum") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - AddSet(sum, Mul(Index(x, i), Index(y, i))) - ) - ret = (sum, lambda ires, icur, node: StmtExpr(AddSet(Index(node, ires), Index(node, icur)))) - - kernel, calls = Unroll(nm, loop, NUNROLL, *ret) - kernel.emit() - - # F.execute(*calls, Return(sum)) - fini(F, loop, F.params[1:3], calls + [Return(sum)], ret) - -def norm(kind): - nm = typeify("norm", kind) - F = Func(pkg(nm), kind, - Params( - (Int, "len"), (Ptr(kind), "x") - ), - Vars( - (Int, "i"), (kind, "nrm"), - ) - ) - len, x, i, nrm = F.variables("len", "x", "i", "nrm") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - AddSet(nrm, Mul(Index(x, i), Index(x, i))) - ) - - ret = (nrm, lambda ires, icur, node: - StmtExpr(AddSet(Index(node, ires), Index(node, icur))) - ) - - kernel, calls = Unroll(nm, loop, NUNROLL, *ret) - kernel.emit() - - if kind == Float64: - sqrt = Func("math·sqrt", kind) - elif kind == Float32: - sqrt = Func("math·sqrtf", kind) - else: - raise ValueError(f"no sqrt for type {kind}") - - # F.execute(*calls, Return(Call(sqrt, [nrm]))) - # F.emit() - fini(F, loop, [F.params[1]], calls + [Return(Call(sqrt, [nrm]))], ret) - -def sum(kind): - name = typeify("sum", kind) - F = Func(pkg(name), kind, - Params( - (Int, "len"), (Ptr(kind), "x"), - ), - Vars( - (Int, "i"), (kind, "sum") - ) - ) - - len, x, i, sum = F.variables("len", "x", "i", "sum") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - AddSet(sum, Index(x, i)) - ) - - ret = (sum, lambda ires, icur, node: - StmtExpr(AddSet(Index(node, ires), Index(node, icur))) - ) - - kernel, calls = Unroll(name, loop, NUNROLL, *ret) - kernel.emit() - - fini(F, loop, [F.params[1]], calls + [Return(sum)], ret) - # F.execute(*calls, Return(sum)) - # F.emit() - -def scale(kind): - name = typeify("scale", kind) - F = Func(pkg(name), Void, - Params( - (Int, "len"), (Ptr(kind), "x"), (kind, "a") - ), - Vars( - (Int, "i"), - ) - ) - - len, a, x, i = F.variables("len", "a", "x", "i") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute(Set(Index(x, i), Mul(a, Index(x, i)))) - - kernel, calls = Unroll(name, loop, NUNROLL) - kernel.emit() - - fini(F, loop, [F.params[1]], calls) - # F.execute(*calls) - # F.emit() - -def rot(kind): - name = typeify("rot", kind) - F = Func(pkg(name), Void, - Params( - (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y"), (kind, "cos"), (kind, "sin") - ), - Vars( - (Int, "i"), (kind, "tmp") - ) - ) - - len, x, y, i, tmp, cos, sin = F.variables("len", "x", "y", "i", "tmp", "cos", "sin") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - Comma(Set(tmp, Index(x, i)), - Comma( - Set(Index(x, i), Add(Mul(cos, Index(x, i)), Mul(sin, Index(y, i)))), - Set(Index(y, i), Sub(Mul(cos, Index(y, i)), Mul(sin, tmp))) - ) - ) - ) - - kernel, calls = Unroll(name, loop, NUNROLL) - kernel.emit() - - fini(F, loop, [F.params[1], F.params[2]], calls) - # F.execute(*calls) - # F.emit() - -def rotm(kind): - name = typeify("rotg", kind) - F = Func(pkg(name), Void, - Params( - (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y"), (Array(kind, 5), "H") - ), - Vars( - (Int, "i"), (kind, "tmp") - ) - ) - - len, x, y, i, tmp, H = F.variables("len", "x", "y", "i", "tmp", "H") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - Comma(Set(tmp, Index(x, i)), - Comma( - Set(Index(x, i), Add(Mul(Index(H, I(1)), Index(x, i)), Mul(Index(H, I(2)), Index(y, i)))), - Set(Index(y, i), Add(Mul(Index(H, I(3)), Index(y, i)), Mul(Index(H, I(4)), tmp))) - ) - ) - ) - - kernel, calls = Unroll(name, loop, NUNROLL) - kernel.emit() - fini(F, loop, F.params[1:3], calls) - -if __name__ == "__main__": - emit("#include <u.h>\n") - emit("#include <libn.h>\n") - emit("#include <libmath.h>\n") - emitln() - emit("/*********************************************************/\n") - emit("/* THIS CODE IS GENERATED BY GEN1.PY! DON'T EDIT BY HAND */\n") - emit("/*********************************************************/\n") - emitln(2) - - # TODO: Implement rotg/rotmg - for kind in [Float32, Float64]: - argminmax(kind) - copy(kind) - axpy(kind) - axpby(kind) - dot(kind) - sum(kind) - norm(kind) - scale(kind) - rot(kind) - rotm(kind) - - flush() |