aboutsummaryrefslogtreecommitdiff
path: root/sys/libmath/gen1.py
diff options
context:
space:
mode:
Diffstat (limited to 'sys/libmath/gen1.py')
-rwxr-xr-xsys/libmath/gen1.py360
1 files changed, 0 insertions, 360 deletions
diff --git a/sys/libmath/gen1.py b/sys/libmath/gen1.py
deleted file mode 100755
index 936bc50..0000000
--- a/sys/libmath/gen1.py
+++ /dev/null
@@ -1,360 +0,0 @@
-#!/bin/python
-
-from C import *
-
-NUNROLL = 16
-def pkg(name):
- return f"blas·{name}"
-
-def typeify(string, kind):
- if (kind == Float32):
- return f"{string}f"
- if (kind == Float64):
- return f"{string}d"
-
-def fini(func, loop, strided, calls, ret=[]):
- func.execute(*calls[:2])
- func, scall = Strided(func, loop, NUNROLL//2, strided, *ret)
- calls[2] = scall[0]
- func.execute(*calls[2:])
- func.emit()
-
-# ------------------------------------------------------------------------
-# Blas level 1 functions
-
-def copy(kind):
- name = typeify("copy", kind)
- F = Func(pkg(name), Void,
- Params(
- (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")
- ),
- Vars(
- (Int, "i")
- )
- )
-
- len, x, y, i = F.variables("len", "x", "y", "i")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- Set(Index(y, i), Index(x, i))
- )
-
- kernel, calls = Unroll(name, loop, NUNROLL)
- kernel.emit()
-
- # F.execute(*calls)
- # F.emit()
- fini(F, loop, F.params[1:3], calls)
-
-def swap(kind):
- name = typeify("swap", kind)
- F = Func(pkg(name), Void,
- Params(
- (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")
- ),
- Vars(
- (Int, "i"), (kind, "tmp")
- )
- )
-
- len, x, y, i, tmp = F.variables("len", "x", "y", "i", "tmp")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- Swap(Index(x, i), Index(y, i), tmp)
- )
-
- kernel, calls = Unroll(name, loop, NUNROLL)
- kernel.emit()
-
- # F.execute(*calls)
- # F.emit()
- fini(F, loop, F.params[1:3], calls)
-
-def axpby(kind):
- name = typeify("axpby", kind)
- F = Func(pkg(name), Void,
- Params(
- (Int, "len"), (kind, "a"), (Ptr(kind), "x"), (kind, "b"), (Ptr(kind), "y")
- ),
- Vars(
- (Int, "i"),
- )
- )
-
- len, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- Set(Index(y, i),
- Mul(b,
- Add(Index(y, i),
- Mul(a, Index(x, i)),
- )
- )
- )
- )
- kernel, calls = Unroll(name, loop, NUNROLL)
- kernel.emit()
-
- # F.execute(*calls)
- # F.emit()
- fini(F, loop, [F.params[2], F.params[4]], calls)
-
-def axpy(kind):
- name = typeify("axpy", kind)
- F = Func(pkg(name), Void,
- Params(
- (Int, "len"), (kind, "a"), (Ptr(kind), "x"), (Ptr(kind), "y")
- ),
- Vars(
- (Int, "i"),
- )
- )
-
- len, x, y, i, a = F.variables("len", "x", "y", "i", "a")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- Set(Index(y, i),
- Add(Index(y, i),
- Mul(a, Index(x, i)),
- )
- )
- )
- kernel, calls = Unroll(name, loop, NUNROLL)
- kernel.emit()
-
- fini(F, loop, F.params[2:4], calls)
- # F.execute(*calls)
- # F.emit()
-
-def argminmax(kind):
- for operation in [("min", LT), ("max", GT)]:
- name = typeify(f"arg{operation[0]}", kind)
- F = Func(pkg(name), Int,
- Params(
- (Int, "len"), (Ptr(kind), "x"),
- ),
- Vars(
- (Int, "i"), (Int, "ix"), (kind, operation[0])
- )
- )
- len, x, i, ix, op = F.variables("len", "x", "i", "ix", operation[0])
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- If(operation[1](Index(x, i), op),
- Block(
- Set(ix, i),
- Set(op, Index(x, ix)),
- )
- )
- )
-
- ret = (ix, lambda ires, icur, node:
- If(GT(Index(x, Index(node, icur)), Index(x, Index(node, ires))),
- Block(Set(Index(node, ires), Index(node, icur)))
- )
- )
-
- kernel, calls = Unroll(name, loop, NUNROLL, *ret)
- kernel.emit()
-
- # F.execute(*calls, Return(ix))
- # F.emit()
- fini(F, loop, [F.params[1]], calls[:2] + [Set(op, Index(x, ix))] + calls[2:] + [Return(ix)], ret)
-
-def dot(kind):
- nm = typeify("dot", kind)
- F = Func(pkg(nm), kind,
- Params(
- (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")
- ),
- Vars(
- (Int, "i"), (kind, "sum"),
- )
- )
- len, x, i, y, sum = F.variables("len", "x", "i", "y", "sum")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- AddSet(sum, Mul(Index(x, i), Index(y, i)))
- )
- ret = (sum, lambda ires, icur, node: StmtExpr(AddSet(Index(node, ires), Index(node, icur))))
-
- kernel, calls = Unroll(nm, loop, NUNROLL, *ret)
- kernel.emit()
-
- # F.execute(*calls, Return(sum))
- fini(F, loop, F.params[1:3], calls + [Return(sum)], ret)
-
-def norm(kind):
- nm = typeify("norm", kind)
- F = Func(pkg(nm), kind,
- Params(
- (Int, "len"), (Ptr(kind), "x")
- ),
- Vars(
- (Int, "i"), (kind, "nrm"),
- )
- )
- len, x, i, nrm = F.variables("len", "x", "i", "nrm")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- AddSet(nrm, Mul(Index(x, i), Index(x, i)))
- )
-
- ret = (nrm, lambda ires, icur, node:
- StmtExpr(AddSet(Index(node, ires), Index(node, icur)))
- )
-
- kernel, calls = Unroll(nm, loop, NUNROLL, *ret)
- kernel.emit()
-
- if kind == Float64:
- sqrt = Func("math·sqrt", kind)
- elif kind == Float32:
- sqrt = Func("math·sqrtf", kind)
- else:
- raise ValueError(f"no sqrt for type {kind}")
-
- # F.execute(*calls, Return(Call(sqrt, [nrm])))
- # F.emit()
- fini(F, loop, [F.params[1]], calls + [Return(Call(sqrt, [nrm]))], ret)
-
-def sum(kind):
- name = typeify("sum", kind)
- F = Func(pkg(name), kind,
- Params(
- (Int, "len"), (Ptr(kind), "x"),
- ),
- Vars(
- (Int, "i"), (kind, "sum")
- )
- )
-
- len, x, i, sum = F.variables("len", "x", "i", "sum")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- AddSet(sum, Index(x, i))
- )
-
- ret = (sum, lambda ires, icur, node:
- StmtExpr(AddSet(Index(node, ires), Index(node, icur)))
- )
-
- kernel, calls = Unroll(name, loop, NUNROLL, *ret)
- kernel.emit()
-
- fini(F, loop, [F.params[1]], calls + [Return(sum)], ret)
- # F.execute(*calls, Return(sum))
- # F.emit()
-
-def scale(kind):
- name = typeify("scale", kind)
- F = Func(pkg(name), Void,
- Params(
- (Int, "len"), (Ptr(kind), "x"), (kind, "a")
- ),
- Vars(
- (Int, "i"),
- )
- )
-
- len, a, x, i = F.variables("len", "a", "x", "i")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(Set(Index(x, i), Mul(a, Index(x, i))))
-
- kernel, calls = Unroll(name, loop, NUNROLL)
- kernel.emit()
-
- fini(F, loop, [F.params[1]], calls)
- # F.execute(*calls)
- # F.emit()
-
-def rot(kind):
- name = typeify("rot", kind)
- F = Func(pkg(name), Void,
- Params(
- (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y"), (kind, "cos"), (kind, "sin")
- ),
- Vars(
- (Int, "i"), (kind, "tmp")
- )
- )
-
- len, x, y, i, tmp, cos, sin = F.variables("len", "x", "y", "i", "tmp", "cos", "sin")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- Comma(Set(tmp, Index(x, i)),
- Comma(
- Set(Index(x, i), Add(Mul(cos, Index(x, i)), Mul(sin, Index(y, i)))),
- Set(Index(y, i), Sub(Mul(cos, Index(y, i)), Mul(sin, tmp)))
- )
- )
- )
-
- kernel, calls = Unroll(name, loop, NUNROLL)
- kernel.emit()
-
- fini(F, loop, [F.params[1], F.params[2]], calls)
- # F.execute(*calls)
- # F.emit()
-
-def rotm(kind):
- name = typeify("rotg", kind)
- F = Func(pkg(name), Void,
- Params(
- (Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y"), (Array(kind, 5), "H")
- ),
- Vars(
- (Int, "i"), (kind, "tmp")
- )
- )
-
- len, x, y, i, tmp, H = F.variables("len", "x", "y", "i", "tmp", "H")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- Comma(Set(tmp, Index(x, i)),
- Comma(
- Set(Index(x, i), Add(Mul(Index(H, I(1)), Index(x, i)), Mul(Index(H, I(2)), Index(y, i)))),
- Set(Index(y, i), Add(Mul(Index(H, I(3)), Index(y, i)), Mul(Index(H, I(4)), tmp)))
- )
- )
- )
-
- kernel, calls = Unroll(name, loop, NUNROLL)
- kernel.emit()
- fini(F, loop, F.params[1:3], calls)
-
-if __name__ == "__main__":
- emit("#include <u.h>\n")
- emit("#include <libn.h>\n")
- emit("#include <libmath.h>\n")
- emitln()
- emit("/*********************************************************/\n")
- emit("/* THIS CODE IS GENERATED BY GEN1.PY! DON'T EDIT BY HAND */\n")
- emit("/*********************************************************/\n")
- emitln(2)
-
- # TODO: Implement rotg/rotmg
- for kind in [Float32, Float64]:
- argminmax(kind)
- copy(kind)
- axpy(kind)
- axpby(kind)
- dot(kind)
- sum(kind)
- norm(kind)
- scale(kind)
- rot(kind)
- rotm(kind)
-
- flush()