From acf54a4b990634e23a88b04363923c4b108a0304 Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Sun, 10 May 2020 20:25:46 -0700 Subject: extract compute kernel prototype --- lib/c.py | 431 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 384 insertions(+), 47 deletions(-) diff --git a/lib/c.py b/lib/c.py index 4f0ae29..54b8e14 100644 --- a/lib/c.py +++ b/lib/c.py @@ -4,10 +4,11 @@ from __future__ import annotations import os import sys -from abc import ABC, abstractmethod -from enum import Enum -from typing import List +from enum import Enum +from typing import List, Tuple, Dict +from functools import singledispatch +numel = len # ------------------------------------------------------------------------ # String buffer @@ -40,8 +41,7 @@ def exits_scope(): # Abstract class everything will derive from # All AST nodes will have an "emit" function that outputs formatted C code -class Emitter(ABC): - @abstractmethod +class Emitter(object): def emit(self): pass @@ -142,8 +142,83 @@ class Expr(Emitter): def emit(): pass +# Literals +class Literal(): + def emit(self): + emit(f"{self}") + +class I(Literal, int): + def __new__(cls, i: int): + return super(I, cls).__new__(cls, i) + +class F(Literal, float): + def __new__(self, f: float): + return super(F, self).__new__(cls, f) + +class S(Literal, str): + def __new__(self, s: str): + return super(S, self).__new__(cls, s) + +# Ident of symbol +class Ident(Expr): + def __init__(self, var): + self.name = var.name + self.var = var + + def emit(self): + emit(f"{self.name}") + +# Unary operators +class UnaryOp(Expr): + def __init__(self, x: Expr): + self.x = x + + def emit(self): + pass + +class Deref(UnaryOp): + def emit(self): + emit("*") + self.x.emit() + +class Negate(UnaryOp): + def emit(self): + emit("~") + self.x.emit() + +class Ref(UnaryOp): + def emit(self): + emit("&") + self.x.emit() + +class Inc(UnaryOp): + def __init__(self, x: Expr, pre=False): + self.x = x + self.pre = pre + + def emit(self): + if self.pre: + emit("++") + self.x.emit() + else: + self.x.emit() + emit("++") + +class Dec(UnaryOp): + def __init__(self, x: Expr, pre=False): + self.x = x + self.pre = pre + + def emit(self): + if self.pre: + emit("--") + self.x.emit() + else: + self.x.emit() + emit("--") + # Binary operators -class BinOp(Expr): +class BinaryOp(Expr): def __init__(self, left: Expr, right: Expr): self.l = left self.r = right @@ -152,55 +227,68 @@ class BinOp(Expr): pass # TODO: check types if they are vectorized and emit correct intrinsic -class Add(BinOp): +class Add(BinaryOp): def emit(self): self.l.emit() emit(f" + ") self.r.emit() -class Sub(BinOp): +class Sub(BinaryOp): def emit(self): self.l.emit() emit(f" - ") self.r.emit() -class Mul(BinOp): +class Mul(BinaryOp): def emit(self): self.l.emit() emit(f" * ") self.r.emit() -class Div(BinOp): +class Div(BinaryOp): def emit(self): self.l.emit() emit(f" / ") self.r.emit() -class Gt(BinOp): +class And(BinaryOp): + def emit(self): + self.l.emit() + emit(f" & ") + self.r.emit() + +class Xor(BinaryOp): + def emit(self): + self.l.emit() + emit(f" ^ ") + self.r.emit() + + +class GT(BinaryOp): def emit(self): self.l.emit() emit(f" > ") self.r.emit() -class Lt(BinOp): +class LT(BinaryOp): def emit(self): self.l.emit() emit(f" < ") self.r.emit() -class Ge(BinOp): +class GE(BinaryOp): def emit(self): self.l.emit() emit(f" >= ") self.r.emit() -class Le(BinOp): +class LE(BinaryOp): def emit(self): self.l.emit() emit(f" <= ") self.r.emit() -class Eq(BinOp): +class EQ(BinaryOp): def emit(self): self.l.emit() emit(f" == ") @@ -212,31 +300,31 @@ class Assign(Expr): self.lhs = lhs self.rhs = rhs -class Mv(Assign): +class Set(Assign): def emit(self): self.lhs.emit() emit(f" = ") self.rhs.emit() -class AddMv(Assign): +class AddSet(Assign): def emit(self): self.lhs.emit() emit(f" += ") self.rhs.emit() -class SubMv(Assign): +class SubSet(Assign): def emit(self): self.lhs.emit() emit(f" -= ") self.rhs.emit() -class MulMv(Assign): +class MulSet(Assign): def emit(self): self.lhs.emit() emit(f" *= ") self.rhs.emit() -class DivMv(Assign): +class DivSet(Assign): def emit(self): self.lhs.emit() emit(f" /= ") @@ -251,6 +339,48 @@ class Comma(Expr): emit(", ") self.expr[1].emit() +class Index(Expr): + def __init__(self, x: Expr, i: Expr): + self.x = x + self.i = i + + def emit(self): + self.x.emit() + emit("[") + self.i.emit() + emit("]") + +class Paren(Expr): + def __init__(self, x: Expr): + self.x = x + + def emit(self): + emit("(") + self.x.emit() + emit(")") + +class Call(Expr): + def __init__(self, func: Func, args: List[Param]): + self.func = func + self.args = args + + def emit(self): + emit(self.func.ident) + emit("(") + if numel(self.args) > 0: + self.args[0].emit() + for arg in self.args[1:]: + emit(", ") + arg.emit() + emit(")") + +def ExprList(*expr: Expr | List[Expr]): + if numel(expr) > 1: + x = expr[0] + return Comma(x, ExprList(*expr[1:])) + else: + return expr[0] + # Common assignments are kept globally # C statements @@ -270,29 +400,38 @@ class Block(Stmt): def emit(self): enter_scope() - for stmt in self.stmts: + for i, stmt in enumerate(self.stmts): stmt.emit() + if i < numel(self.stmts) - 1: + emitln() + exits_scope() - super(Block, self).emit() class For(Stmt): - def __init__(self, init: Expr, cond: Expr, step: Expr, body: Stmt): - self.init = init - self.cond = cond - self.step = step + def __init__(self, init: Expr | List[Expr], cond: Expr | List[Expr], step: Expr | List[Expr], body: Stmt): + self.init = init if isinstance(init, Expr) else ExprList(*init) + self.cond = cond if isinstance(cond, Expr) else ExprList(*cond) + self.step = step if isinstance(step, Expr) else ExprList(*step) self.body = body def emit(self): emit("for (") - self.init.emit() - emit(";") - self.cond.emit() - emit(";") - self.step.emit() - emit(")") - - self.body.emit() - super(For, self).emit() + if self.init is not None: + self.init.emit() + emit("; ") + if self.cond is not None: + self.cond.emit() + emit("; ") + if self.step is not None: + self.step.emit() + emit(") ") + + if isinstance(self.body, Block): + self.body.emit() + else: + enter_scope() + self.body.emit() + exits_scope() class Return(Stmt): def __init__(self, val: Expr): @@ -327,12 +466,12 @@ class Decl(Emitter): pass class Func(Decl): - def __init__(self, ident: str, ret: Expr = Void, params: List[Param] = [], vars: List[Var | List[Var]] = [], body: List[Stmt] = []): + def __init__(self, ident: str, ret: Expr = Void, params: List[Param] = None, vars: List[Var | List[Var]] = None, body: List[Stmt] = None): self.ident = ident self.ret = ret - self.params = params - self.vars = vars - self.stmts = body + self.params = params if params is not None else [] + self.vars = vars if vars is not None else [] + self.stmts = body if body is not None else [] def emit(self): self.ret.emit() @@ -343,7 +482,7 @@ class Func(Decl): p.emittype() emit(" ") p.emitspec() - if i < len(self.params) - 1: + if i < numel(self.params) - 1: emit(", ") emit(")\n") @@ -366,17 +505,29 @@ class Func(Decl): emit(";") emitln() - emitln() - for stmt in self.stmts: + if numel(self.vars) > 0: + emitln() + + for stmt in self.stmts[:-1]: stmt.emit() + emitln() + if numel(self.stmts) > 0: + self.stmts[-1].emit() exits_scope() - def declare(self, var: Var, *vars: List[Var]): + def declare(self, var: Var, *vars: List[Var | List[Var]]) -> Expr | List[Expr]: self.vars.append(var) + if numel(vars) == 0: + return Ident(var) + self.vars.extend(vars) - def instruct(self, stmt: Stmt | Expr, *args: List[Stmt | Expr]): + idents = [Ident(var)] + idents += [Ident(v) for v in vars] + return idents + + def execute(self, stmt: Stmt | Expr, *args: List[Stmt | Expr]): def push(n): if isinstance(n, Stmt): self.stmts.append(n) @@ -384,11 +535,14 @@ class Func(Decl): self.stmts.append(StmtExpr(n)) else: raise TypeError("unrecognized type for function") - push(stmt) for arg in args: push(arg) + def variables(self, *idents: List[str]) -> List[Expr]: + vars = {v.name : v for v in self.vars + self.params} + return [Ident(vars[ident]) for ident in idents] + class Var(Decl): def __init__(self, type: Type, name: str, storage: Mem = Mem.Auto): self.name = name @@ -409,7 +563,190 @@ class Var(Decl): class Param(Var): def __init__(self, type: Type, name: str): - return super(Param, self).__init__(type, name, mem.Auto) + return super(Param, self).__init__(type, name, Mem.Auto) + +def Params(*ps: List[Tuple(Type, str)]) -> List[Param]: + return [Param(p[0], p[1]) for p in ps] + +def Vars(*ps: List[Tuple(Type, str)]) -> List[Var]: + return [Var(p[0], p[1]) for p in ps] + +# ------------------------------------------------------------------------ +# AST modification/production functions + +def Swap(x: Var, y: Var, tmp: Var) -> Stmt: + return StmtExpr(ExprList(Set(tmp, x), Set(x, y), Set(y, tmp))) + +@singledispatch +def VarsUsed(s: object) -> List[Vars]: + raise TypeError(f"{type(s)} not supported by VarsUsed operation") + +@VarsUsed.register +def _(op: UnaryOp): + return VarsUsed(op.x) + +@VarsUsed.register +def _(sym: Ident): + return [sym.var] + +@VarsUsed.register +def _(s: Assign): + vars = [] + vars.extend(VarsUsed(s.lhs)) + vars.extend(VarsUsed(s.rhs)) + return vars + +@VarsUsed.register +def _(lit: Literal): + return [] + +@VarsUsed.register +def _(comma: Comma): + vars = [] + vars.extend(VarsUsed(comma.expr[0])) + vars.extend(VarsUsed(comma.expr[1])) + return vars + +@VarsUsed.register +def _(op: BinaryOp): + vars = [] + vars.extend(VarsUsed(op.l)) + vars.extend(VarsUsed(op.r)) + return vars + +@VarsUsed.register +def _(s: Empty): + return [] + +@VarsUsed.register +def _(blk: Block): + vars = [] + for stmt in blk.stmts: + vars.extend(VarsUsed(stmt)) + return vars + +@VarsUsed.register +def _(loop: For): + vars = [] + vars.extend(VarsUsed(loop.init)) + vars.extend(VarsUsed(loop.cond)) + vars.extend(VarsUsed(loop.step)) + vars.extend(VarsUsed(loop.body)) + return vars + +@VarsUsed.register +def _(ret: Return): + return VarsUsed(ret.val) + +@VarsUsed.register +def _(x: StmtExpr): + return VarsUsed(x.x) + +@singledispatch +def Repeat(x: Expr, times: int) -> Expr | List[Expr]: + raise TypeError(f"{type(x)} not supported by Repeat operation") + +@Repeat.register +def _(x: Inc, times: int): + return AddSet(x.x, I(times)) + +@Repeat.register +def _(x: Dec, times: int): + return DecSet(x.x, I(times)) + +@Repeat.register +def _(x: Comma, times: int): + return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times)) + +# TODO: Parameterize the variables that should not be expanded +@singledispatch +def Step(x: Expr, i: int) -> Expr: + raise TypeError(f"{type(x)} not supported by Step operation") + +@Step.register +def _(x: Comma, i: int): + return Comma(Step(x.expr[0], i), Step(x.expr[1], i)) + +@Step.register +def _(x: Assign, i: int): + return type(x)(Step(x.lhs, i), Step(x.rhs, i)) + +@Step.register +def _(x: Ident, i: int): + return x + +@Step.register +def _(x: Deref, i: int): + return Index(x.x, I(i)) + +def Expand(s: StmtExpr, times: int) -> Block: + if not isinstance(s, StmtExpr): + raise TypeError(f"{type(x)} not supported by Expand operation") + + return Block([StmtExpr(Step(s.x, i)) for i in range(times)]) + +def EvenTo(x: Var, n: int) -> Var: + return And(x, Negate(I(n-1))) + +def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Func, Call): + # TODO: More sophisticated computation for length of loop + if not isinstance(loop.cond, LE) and not isinstance(loop.cond, LT): + raise TypeError(f"{type(loop.cond)} not supported in loop unrolling") + + # pull off needed features of the loop + it = loop.init.lhs.var + vars = set(VarsUsed(loop.body)) + + params = [v for v in vars if type(v) == Param] + stacks = [v for v in vars if type(v) == Var] + + # TODO: More sophisticated type checking + n = loop.cond.r.var + if (type(n) != Param): + raise TypeError(f"{type(n)} not implemented yet") + params = [n] + params + + kernel = Func(f"{name}{times}", Int, params, stacks) + body = loop.body + + kernel.execute( + Set(n, EvenTo(n, times)), + For(Set(it, I(0)), LT(it, n), Repeat(loop.step, times), + body=Expand(loop.body, times) + ), + Return(n) + ) + + loop.init = None + return loop, kernel, Call(kernel, params) # ------------------------------------------------------------------------ -# AST modification functions +# Point of testing + +if __name__ == "__main__": + Rot = Func("blas·swap", Void, + params = Params( + (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y") + ), + vars = Vars( + (Int, "i"), (Float64, "tmp") + ) + ) + # Rot.declare( ... ) + + # TODO: Increase ergonomics here... + len, x, y, i, tmp = Rot.variables("len", "x", "y", "i", "tmp") + + loop = For(Set(i, I(0)), LT(i, len), [Inc(i), Inc(x), Inc(y)], + body = Swap(Deref(x), Deref(y), tmp) + ) + + rem, kernel, it = Unroll(loop, 8, "swap_kernel") + kernel.emit() + emitln(2) + + Rot.execute(Set(i, it), rem) + + Rot.emit() + + print(buffer) -- cgit v1.2.1