#!/bin/python
"""Generator for the level-1 BLAS C routines.

Each public function below builds one routine with the node types exported
by the project-local ``C`` module, emits an unrolled kernel for it, and then
emits the public entry point through :func:`fini`.  Running the file as a
script emits every routine for ``Float32`` and ``Float64``.
"""

from C import *

# Unroll factor handed to Unroll(); fini() hands NUNROLL // 2 to Strided().
NUNROLL = 16


def pkg(name):
    """Return ``name`` prefixed with the ``blas·`` package qualifier."""
    return f"blas·{name}"


def typeify(string, kind):
    """Return ``string`` with the suffix for ``kind`` appended.

    ``Float32`` maps to an ``f`` suffix and ``Float64`` to a ``d`` suffix.

    Raises:
        ValueError: if ``kind`` is neither ``Float32`` nor ``Float64``.
            (Previously this fell through and returned ``None``, which
            produced a routine named ``blas·None``.)
    """
    if kind == Float32:
        return f"{string}f"
    if kind == Float64:
        return f"{string}d"
    raise ValueError(f"no type suffix for type {kind}")


def fini(func, loop, strided, calls, ret=()):
    """Finish and emit ``func``.

    The first two entries of ``calls`` are executed on ``func``.  ``Strided``
    is then invoked with ``func``, ``loop``, ``NUNROLL // 2``, ``strided`` and
    the unpacked ``ret``; it returns a replacement function and a sequence
    whose first element replaces ``calls[2]``.  The remaining calls are
    executed on the replacement function, which is finally emitted.

    Args:
        func: function under construction.
        loop: the loop the kernel was generated from.
        strided: parameters forwarded to ``Strided``.
        calls: list of statements; modified in place (index 2 is overwritten).
        ret: optional ``(accumulator, merge_callback)`` pair forwarded to
            ``Strided``; empty when the routine has no result to merge.
    """
    func.execute(*calls[:2])
    func, scall = Strided(func, loop, NUNROLL // 2, strided, *ret)
    calls[2] = scall[0]
    func.execute(*calls[2:])
    func.emit()


def _counting_loop(i, n):
    """Build a fresh ``for (i = 0; i < n; i++)`` loop node."""
    return For(Set(i, I(0)), LT(i, n), [Inc(i)])


def _add_partial(ires, icur, node):
    """Merge callback for sums: ``node[ires] += node[icur]``."""
    return StmtExpr(AddSet(Index(node, ires), Index(node, icur)))


def _arg_reducer(x, compare):
    """Build the merge callback used by :func:`argminmax`.

    The callback keeps in ``node[ires]`` whichever of the two candidate
    indices refers to the element of ``x`` preferred by ``compare``.
    ``x`` and ``compare`` are bound here so the callback does not depend on
    loop variables of the caller.
    """
    def reduce(ires, icur, node):
        return If(
            compare(Index(x, Index(node, icur)), Index(x, Index(node, ires))),
            Block(Set(Index(node, ires), Index(node, icur))),
        )
    return reduce


# ------------------------------------------------------------------------
# Blas level 1 functions

def copy(kind):
    """Emit ``copy``: ``y[i] = x[i]`` for ``i`` in ``[0, len)``."""
    name = typeify("copy", kind)
    F = Func(
        pkg(name), Void,
        Params((Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")),
        Vars((Int, "i")),
    )
    n, x, y, i = F.variables("len", "x", "y", "i")

    loop = _counting_loop(i, n)
    loop.execute(Set(Index(y, i), Index(x, i)))

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()
    fini(F, loop, F.params[1:3], calls)


def swap(kind):
    """Emit ``swap``: exchange ``x[i]`` and ``y[i]`` using a temporary."""
    name = typeify("swap", kind)
    F = Func(
        pkg(name), Void,
        Params((Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")),
        Vars((Int, "i"), (kind, "tmp")),
    )
    n, x, y, i, tmp = F.variables("len", "x", "y", "i", "tmp")

    loop = _counting_loop(i, n)
    loop.execute(Swap(Index(x, i), Index(y, i), tmp))

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()
    fini(F, loop, F.params[1:3], calls)


def axpby(kind):
    """Emit ``axpby``: ``y[i] = a * x[i] + b * y[i]``.

    The earlier formulation emitted ``b * (y[i] + a * x[i])``, which scales
    the ``x`` term by ``a * b`` rather than ``a``.
    """
    name = typeify("axpby", kind)
    F = Func(
        pkg(name), Void,
        Params(
            (Int, "len"),
            (kind, "a"), (Ptr(kind), "x"),
            (kind, "b"), (Ptr(kind), "y"),
        ),
        Vars((Int, "i")),
    )
    n, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b")

    loop = _counting_loop(i, n)
    loop.execute(
        Set(
            Index(y, i),
            Add(Mul(a, Index(x, i)), Mul(b, Index(y, i))),
        )
    )

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()
    # Parameters 2 and 4 are the two vectors, x and y.
    fini(F, loop, [F.params[2], F.params[4]], calls)


def axpy(kind):
    """Emit ``axpy``: ``y[i] = y[i] + a * x[i]``."""
    name = typeify("axpy", kind)
    F = Func(
        pkg(name), Void,
        Params((Int, "len"), (kind, "a"), (Ptr(kind), "x"), (Ptr(kind), "y")),
        Vars((Int, "i")),
    )
    n, x, y, i, a = F.variables("len", "x", "y", "i", "a")

    loop = _counting_loop(i, n)
    loop.execute(
        Set(Index(y, i), Add(Index(y, i), Mul(a, Index(x, i))))
    )

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()
    fini(F, loop, F.params[2:4], calls)


def argminmax(kind):
    """Emit ``argmin`` and ``argmax``, each returning an index into ``x``.

    Both routines share one template; only the comparison differs (``LT``
    for min, ``GT`` for max).  The same comparison is used when merging
    partial results, so ``argmin`` merges toward the smaller element.
    (The merge step previously used ``GT`` for both routines.)
    """
    for label, compare in (("min", LT), ("max", GT)):
        name = typeify(f"arg{label}", kind)
        F = Func(
            pkg(name), Int,
            Params((Int, "len"), (Ptr(kind), "x")),
            Vars((Int, "i"), (Int, "ix"), (kind, label)),
        )
        n, x, i, ix, best = F.variables("len", "x", "i", "ix", label)

        loop = _counting_loop(i, n)
        loop.execute(
            If(
                compare(Index(x, i), best),
                Block(
                    Set(ix, i),
                    Set(best, Index(x, ix)),
                ),
            )
        )

        ret = (ix, _arg_reducer(x, compare))

        kernel, calls = Unroll(name, loop, NUNROLL, *ret)
        kernel.emit()

        # A statement loading the current best value from x[ix] is spliced in
        # after the first two calls; the routine returns the index.
        body = calls[:2] + [Set(best, Index(x, ix))] + calls[2:] + [Return(ix)]
        fini(F, loop, [F.params[1]], body, ret)


def dot(kind):
    """Emit ``dot``: accumulate ``x[i] * y[i]`` into ``sum`` and return it."""
    nm = typeify("dot", kind)
    F = Func(
        pkg(nm), kind,
        Params((Int, "len"), (Ptr(kind), "x"), (Ptr(kind), "y")),
        Vars((Int, "i"), (kind, "sum")),
    )
    n, x, i, y, acc = F.variables("len", "x", "i", "y", "sum")

    loop = _counting_loop(i, n)
    loop.execute(AddSet(acc, Mul(Index(x, i), Index(y, i))))

    ret = (acc, _add_partial)

    kernel, calls = Unroll(nm, loop, NUNROLL, *ret)
    kernel.emit()
    fini(F, loop, F.params[1:3], calls + [Return(acc)], ret)


def norm(kind):
    """Emit ``norm``: accumulate ``x[i] * x[i]`` and return its square root.

    The square root is a call to ``math·sqrt`` for ``Float64`` and
    ``math·sqrtf`` for ``Float32``.

    Raises:
        ValueError: if ``kind`` has no matching square-root routine.
    """
    nm = typeify("norm", kind)
    F = Func(
        pkg(nm), kind,
        Params((Int, "len"), (Ptr(kind), "x")),
        Vars((Int, "i"), (kind, "nrm")),
    )
    n, x, i, nrm = F.variables("len", "x", "i", "nrm")

    loop = _counting_loop(i, n)
    loop.execute(AddSet(nrm, Mul(Index(x, i), Index(x, i))))

    ret = (nrm, _add_partial)

    kernel, calls = Unroll(nm, loop, NUNROLL, *ret)
    kernel.emit()

    if kind == Float64:
        sqrt = Func("math·sqrt", kind)
    elif kind == Float32:
        sqrt = Func("math·sqrtf", kind)
    else:
        raise ValueError(f"no sqrt for type {kind}")

    fini(F, loop, [F.params[1]], calls + [Return(Call(sqrt, [nrm]))], ret)


def sum(kind):
    """Emit ``sum``: accumulate ``x[i]`` into ``sum`` and return it."""
    name = typeify("sum", kind)
    F = Func(
        pkg(name), kind,
        Params((Int, "len"), (Ptr(kind), "x")),
        Vars((Int, "i"), (kind, "sum")),
    )
    n, x, i, acc = F.variables("len", "x", "i", "sum")

    loop = _counting_loop(i, n)
    loop.execute(AddSet(acc, Index(x, i)))

    ret = (acc, _add_partial)

    kernel, calls = Unroll(name, loop, NUNROLL, *ret)
    kernel.emit()
    fini(F, loop, [F.params[1]], calls + [Return(acc)], ret)


def scale(kind):
    """Emit ``scale``: ``x[i] = a * x[i]``."""
    name = typeify("scale", kind)
    F = Func(
        pkg(name), Void,
        Params((Int, "len"), (Ptr(kind), "x"), (kind, "a")),
        Vars((Int, "i")),
    )
    n, a, x, i = F.variables("len", "a", "x", "i")

    loop = _counting_loop(i, n)
    loop.execute(Set(Index(x, i), Mul(a, Index(x, i))))

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()
    fini(F, loop, [F.params[1]], calls)


def rot(kind):
    """Emit ``rot``: rotate each pair ``(x[i], y[i])`` by ``cos``/``sin``.

    Per element, with ``tmp`` holding the old ``x[i]``::

        x[i] = cos * x[i] + sin * y[i]
        y[i] = cos * y[i] - sin * tmp
    """
    name = typeify("rot", kind)
    F = Func(
        pkg(name), Void,
        Params(
            (Int, "len"),
            (Ptr(kind), "x"), (Ptr(kind), "y"),
            (kind, "cos"), (kind, "sin"),
        ),
        Vars((Int, "i"), (kind, "tmp")),
    )
    n, x, y, i, tmp, cos, sin = F.variables(
        "len", "x", "y", "i", "tmp", "cos", "sin"
    )

    new_x = Add(Mul(cos, Index(x, i)), Mul(sin, Index(y, i)))
    new_y = Sub(Mul(cos, Index(y, i)), Mul(sin, tmp))

    loop = _counting_loop(i, n)
    loop.execute(
        Comma(
            Set(tmp, Index(x, i)),
            Comma(
                Set(Index(x, i), new_x),
                Set(Index(y, i), new_y),
            ),
        )
    )

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()
    fini(F, loop, [F.params[1], F.params[2]], calls)


def rotm(kind):
    """Emit ``rotm``: apply the rotation stored in ``H`` to ``(x[i], y[i])``.

    Per element, with ``tmp`` holding the old ``x[i]``::

        x[i] = H[1] * x[i] + H[2] * y[i]
        y[i] = H[3] * y[i] + H[4] * tmp

    The emitted symbol is ``rotm``; it was previously emitted as ``rotg``,
    which is a different routine that this file does not implement yet.
    """
    name = typeify("rotm", kind)
    F = Func(
        pkg(name), Void,
        Params(
            (Int, "len"),
            (Ptr(kind), "x"), (Ptr(kind), "y"),
            (Array(kind, 5), "H"),
        ),
        Vars((Int, "i"), (kind, "tmp")),
    )
    n, x, y, i, tmp, H = F.variables("len", "x", "y", "i", "tmp", "H")

    new_x = Add(
        Mul(Index(H, I(1)), Index(x, i)),
        Mul(Index(H, I(2)), Index(y, i)),
    )
    new_y = Add(
        Mul(Index(H, I(3)), Index(y, i)),
        Mul(Index(H, I(4)), tmp),
    )

    loop = _counting_loop(i, n)
    loop.execute(
        Comma(
            Set(tmp, Index(x, i)),
            Comma(
                Set(Index(x, i), new_x),
                Set(Index(y, i), new_y),
            ),
        )
    )

    kernel, calls = Unroll(name, loop, NUNROLL)
    kernel.emit()
    fini(F, loop, F.params[1:3], calls)


def main():
    """Emit the file header, then every routine for both float kinds."""
    # NOTE(review): no header name follows "#include" in these three strings;
    # confirm the intended headers before relying on the generated output.
    emit("#include \n")
    emit("#include \n")
    emit("#include \n")
    emitln()

    emit("/*********************************************************/\n")
    emit("/* THIS CODE IS GENERATED BY GEN1.PY! DON'T EDIT BY HAND */\n")
    emit("/*********************************************************/\n")
    emitln(2)

    # TODO: Implement rotg/rotmg
    # NOTE(review): swap() is defined above but is not generated here;
    # confirm whether that omission is intentional.
    for kind in [Float32, Float64]:
        argminmax(kind)
        copy(kind)
        axpy(kind)
        axpby(kind)
        dot(kind)
        sum(kind)
        norm(kind)
        scale(kind)
        rot(kind)
        rotm(kind)

    flush()


if __name__ == "__main__":
    main()