aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-10 20:25:46 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-10 20:25:46 -0700
commitacf54a4b990634e23a88b04363923c4b108a0304 (patch)
treecaeba25823b98abaa6eaa9b33a56ab9eacd76672
parent9828864af9fb9d39d36e4f3811ce8c3283e8433d (diff)
extract compute kernel prototype
-rw-r--r--lib/c.py431
1 files changed, 384 insertions, 47 deletions
diff --git a/lib/c.py b/lib/c.py
index 4f0ae29..54b8e14 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -4,10 +4,11 @@ from __future__ import annotations
import os
import sys
-from abc import ABC, abstractmethod
-from enum import Enum
-from typing import List
+from enum import Enum
+from typing import List, Tuple, Dict
+from functools import singledispatch
+numel = len
# ------------------------------------------------------------------------
# String buffer
@@ -40,8 +41,7 @@ def exits_scope():
# Abstract class everything will derive from
# All AST nodes will have an "emit" function that outputs formatted C code
-class Emitter(ABC):
- @abstractmethod
+class Emitter(object):
def emit(self):
pass
@@ -142,8 +142,83 @@ class Expr(Emitter):
def emit():
pass
+# Literals
+class Literal():
+ def emit(self):
+ emit(f"{self}")
+
+class I(Literal, int):
+ def __new__(cls, i: int):
+ return super(I, cls).__new__(cls, i)
+
+class F(Literal, float):
+ def __new__(self, f: float):
+ return super(F, self).__new__(cls, f)
+
+class S(Literal, str):
+ def __new__(self, s: str):
+ return super(S, self).__new__(cls, s)
+
+# Ident of symbol
+class Ident(Expr):
+ def __init__(self, var):
+ self.name = var.name
+ self.var = var
+
+ def emit(self):
+ emit(f"{self.name}")
+
+# Unary operators
+class UnaryOp(Expr):
+ def __init__(self, x: Expr):
+ self.x = x
+
+ def emit(self):
+ pass
+
+class Deref(UnaryOp):
+ def emit(self):
+ emit("*")
+ self.x.emit()
+
+class Negate(UnaryOp):
+ def emit(self):
+ emit("~")
+ self.x.emit()
+
+class Ref(UnaryOp):
+ def emit(self):
+ emit("&")
+ self.x.emit()
+
+class Inc(UnaryOp):
+ def __init__(self, x: Expr, pre=False):
+ self.x = x
+ self.pre = pre
+
+ def emit(self):
+ if self.pre:
+ emit("++")
+ self.x.emit()
+ else:
+ self.x.emit()
+ emit("++")
+
+class Dec(UnaryOp):
+ def __init__(self, x: Expr, pre=False):
+ self.x = x
+ self.pre = pre
+
+ def emit(self):
+ if self.pre:
+ emit("--")
+ self.x.emit()
+ else:
+ self.x.emit()
+ emit("--")
+
# Binary operators
-class BinOp(Expr):
+class BinaryOp(Expr):
def __init__(self, left: Expr, right: Expr):
self.l = left
self.r = right
@@ -152,55 +227,68 @@ class BinOp(Expr):
pass
# TODO: check types if they are vectorized and emit correct intrinsic
-class Add(BinOp):
+class Add(BinaryOp):
def emit(self):
self.l.emit()
emit(f" + ")
self.r.emit()
-class Sub(BinOp):
+class Sub(BinaryOp):
def emit(self):
self.l.emit()
emit(f" - ")
self.r.emit()
-class Mul(BinOp):
+class Mul(BinaryOp):
def emit(self):
self.l.emit()
emit(f" * ")
self.r.emit()
-class Div(BinOp):
+class Div(BinaryOp):
def emit(self):
self.l.emit()
emit(f" / ")
self.r.emit()
-class Gt(BinOp):
+class And(BinaryOp):
+ def emit(self):
+ self.l.emit()
+ emit(f" & ")
+ self.r.emit()
+
+class Xor(BinaryOp):
+ def emit(self):
+ self.l.emit()
+ emit(f" ^ ")
+ self.r.emit()
+
+
+class GT(BinaryOp):
def emit(self):
self.l.emit()
emit(f" > ")
self.r.emit()
-class Lt(BinOp):
+class LT(BinaryOp):
def emit(self):
self.l.emit()
emit(f" < ")
self.r.emit()
-class Ge(BinOp):
+class GE(BinaryOp):
def emit(self):
self.l.emit()
emit(f" >= ")
self.r.emit()
-class Le(BinOp):
+class LE(BinaryOp):
def emit(self):
self.l.emit()
emit(f" <= ")
self.r.emit()
-class Eq(BinOp):
+class EQ(BinaryOp):
def emit(self):
self.l.emit()
emit(f" == ")
@@ -212,31 +300,31 @@ class Assign(Expr):
self.lhs = lhs
self.rhs = rhs
-class Mv(Assign):
+class Set(Assign):
def emit(self):
self.lhs.emit()
emit(f" = ")
self.rhs.emit()
-class AddMv(Assign):
+class AddSet(Assign):
def emit(self):
self.lhs.emit()
emit(f" += ")
self.rhs.emit()
-class SubMv(Assign):
+class SubSet(Assign):
def emit(self):
self.lhs.emit()
emit(f" -= ")
self.rhs.emit()
-class MulMv(Assign):
+class MulSet(Assign):
def emit(self):
self.lhs.emit()
emit(f" *= ")
self.rhs.emit()
-class DivMv(Assign):
+class DivSet(Assign):
def emit(self):
self.lhs.emit()
emit(f" /= ")
@@ -251,6 +339,48 @@ class Comma(Expr):
emit(", ")
self.expr[1].emit()
+class Index(Expr):
+ def __init__(self, x: Expr, i: Expr):
+ self.x = x
+ self.i = i
+
+ def emit(self):
+ self.x.emit()
+ emit("[")
+ self.i.emit()
+ emit("]")
+
+class Paren(Expr):
+ def __init__(self, x: Expr):
+ self.x = x
+
+ def emit(self):
+ emit("(")
+ self.x.emit()
+ emit(")")
+
+class Call(Expr):
+ def __init__(self, func: Func, args: List[Param]):
+ self.func = func
+ self.args = args
+
+ def emit(self):
+ emit(self.func.ident)
+ emit("(")
+ if numel(self.args) > 0:
+ self.args[0].emit()
+ for arg in self.args[1:]:
+ emit(", ")
+ arg.emit()
+ emit(")")
+
+def ExprList(*expr: Expr | List[Expr]):
+ if numel(expr) > 1:
+ x = expr[0]
+ return Comma(x, ExprList(*expr[1:]))
+ else:
+ return expr[0]
+
# Common assignments are kept globally
# C statements
@@ -270,29 +400,38 @@ class Block(Stmt):
def emit(self):
enter_scope()
- for stmt in self.stmts:
+ for i, stmt in enumerate(self.stmts):
stmt.emit()
+ if i < numel(self.stmts) - 1:
+ emitln()
+
exits_scope()
- super(Block, self).emit()
class For(Stmt):
- def __init__(self, init: Expr, cond: Expr, step: Expr, body: Stmt):
- self.init = init
- self.cond = cond
- self.step = step
+ def __init__(self, init: Expr | List[Expr], cond: Expr | List[Expr], step: Expr | List[Expr], body: Stmt):
+ self.init = init if isinstance(init, Expr) else ExprList(*init)
+ self.cond = cond if isinstance(cond, Expr) else ExprList(*cond)
+ self.step = step if isinstance(step, Expr) else ExprList(*step)
self.body = body
def emit(self):
emit("for (")
- self.init.emit()
- emit(";")
- self.cond.emit()
- emit(";")
- self.step.emit()
- emit(")")
-
- self.body.emit()
- super(For, self).emit()
+ if self.init is not None:
+ self.init.emit()
+ emit("; ")
+ if self.cond is not None:
+ self.cond.emit()
+ emit("; ")
+ if self.step is not None:
+ self.step.emit()
+ emit(") ")
+
+ if isinstance(self.body, Block):
+ self.body.emit()
+ else:
+ enter_scope()
+ self.body.emit()
+ exits_scope()
class Return(Stmt):
def __init__(self, val: Expr):
@@ -327,12 +466,12 @@ class Decl(Emitter):
pass
class Func(Decl):
- def __init__(self, ident: str, ret: Expr = Void, params: List[Param] = [], vars: List[Var | List[Var]] = [], body: List[Stmt] = []):
+ def __init__(self, ident: str, ret: Expr = Void, params: List[Param] = None, vars: List[Var | List[Var]] = None, body: List[Stmt] = None):
self.ident = ident
self.ret = ret
- self.params = params
- self.vars = vars
- self.stmts = body
+ self.params = params if params is not None else []
+ self.vars = vars if vars is not None else []
+ self.stmts = body if body is not None else []
def emit(self):
self.ret.emit()
@@ -343,7 +482,7 @@ class Func(Decl):
p.emittype()
emit(" ")
p.emitspec()
- if i < len(self.params) - 1:
+ if i < numel(self.params) - 1:
emit(", ")
emit(")\n")
@@ -366,17 +505,29 @@ class Func(Decl):
emit(";")
emitln()
- emitln()
- for stmt in self.stmts:
+ if numel(self.vars) > 0:
+ emitln()
+
+ for stmt in self.stmts[:-1]:
stmt.emit()
+ emitln()
+ if numel(self.stmts) > 0:
+ self.stmts[-1].emit()
exits_scope()
- def declare(self, var: Var, *vars: List[Var]):
+ def declare(self, var: Var, *vars: List[Var | List[Var]]) -> Expr | List[Expr]:
self.vars.append(var)
+ if numel(vars) == 0:
+ return Ident(var)
+
self.vars.extend(vars)
- def instruct(self, stmt: Stmt | Expr, *args: List[Stmt | Expr]):
+ idents = [Ident(var)]
+ idents += [Ident(v) for v in vars]
+ return idents
+
+ def execute(self, stmt: Stmt | Expr, *args: List[Stmt | Expr]):
def push(n):
if isinstance(n, Stmt):
self.stmts.append(n)
@@ -384,11 +535,14 @@ class Func(Decl):
self.stmts.append(StmtExpr(n))
else:
raise TypeError("unrecognized type for function")
-
push(stmt)
for arg in args:
push(arg)
+ def variables(self, *idents: List[str]) -> List[Expr]:
+ vars = {v.name : v for v in self.vars + self.params}
+ return [Ident(vars[ident]) for ident in idents]
+
class Var(Decl):
def __init__(self, type: Type, name: str, storage: Mem = Mem.Auto):
self.name = name
@@ -409,7 +563,190 @@ class Var(Decl):
class Param(Var):
def __init__(self, type: Type, name: str):
- return super(Param, self).__init__(type, name, mem.Auto)
+ return super(Param, self).__init__(type, name, Mem.Auto)
+
+def Params(*ps: List[Tuple(Type, str)]) -> List[Param]:
+ return [Param(p[0], p[1]) for p in ps]
+
+def Vars(*ps: List[Tuple(Type, str)]) -> List[Var]:
+ return [Var(p[0], p[1]) for p in ps]
+
+# ------------------------------------------------------------------------
+# AST modification/production functions
+
+def Swap(x: Var, y: Var, tmp: Var) -> Stmt:
+ return StmtExpr(ExprList(Set(tmp, x), Set(x, y), Set(y, tmp)))
+
+@singledispatch
+def VarsUsed(s: object) -> List[Vars]:
+ raise TypeError(f"{type(s)} not supported by VarsUsed operation")
+
+@VarsUsed.register
+def _(op: UnaryOp):
+ return VarsUsed(op.x)
+
+@VarsUsed.register
+def _(sym: Ident):
+ return [sym.var]
+
+@VarsUsed.register
+def _(s: Assign):
+ vars = []
+ vars.extend(VarsUsed(s.lhs))
+ vars.extend(VarsUsed(s.rhs))
+ return vars
+
+@VarsUsed.register
+def _(lit: Literal):
+ return []
+
+@VarsUsed.register
+def _(comma: Comma):
+ vars = []
+ vars.extend(VarsUsed(comma.expr[0]))
+ vars.extend(VarsUsed(comma.expr[1]))
+ return vars
+
+@VarsUsed.register
+def _(op: BinaryOp):
+ vars = []
+ vars.extend(VarsUsed(op.l))
+ vars.extend(VarsUsed(op.r))
+ return vars
+
+@VarsUsed.register
+def _(s: Empty):
+ return []
+
+@VarsUsed.register
+def _(blk: Block):
+ vars = []
+ for stmt in blk.stmts:
+ vars.extend(VarsUsed(stmt))
+ return vars
+
+@VarsUsed.register
+def _(loop: For):
+ vars = []
+ vars.extend(VarsUsed(loop.init))
+ vars.extend(VarsUsed(loop.cond))
+ vars.extend(VarsUsed(loop.step))
+ vars.extend(VarsUsed(loop.body))
+ return vars
+
+@VarsUsed.register
+def _(ret: Return):
+ return VarsUsed(ret.val)
+
+@VarsUsed.register
+def _(x: StmtExpr):
+ return VarsUsed(x.x)
+
+@singledispatch
+def Repeat(x: Expr, times: int) -> Expr | List[Expr]:
+ raise TypeError(f"{type(x)} not supported by Repeat operation")
+
+@Repeat.register
+def _(x: Inc, times: int):
+ return AddSet(x.x, I(times))
+
+@Repeat.register
+def _(x: Dec, times: int):
+ return DecSet(x.x, I(times))
+
+@Repeat.register
+def _(x: Comma, times: int):
+ return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times))
+
+# TODO: Parameterize the variables that should not be expanded
+@singledispatch
+def Step(x: Expr, i: int) -> Expr:
+ raise TypeError(f"{type(x)} not supported by Step operation")
+
+@Step.register
+def _(x: Comma, i: int):
+ return Comma(Step(x.expr[0], i), Step(x.expr[1], i))
+
+@Step.register
+def _(x: Assign, i: int):
+ return type(x)(Step(x.lhs, i), Step(x.rhs, i))
+
+@Step.register
+def _(x: Ident, i: int):
+ return x
+
+@Step.register
+def _(x: Deref, i: int):
+ return Index(x.x, I(i))
+
+def Expand(s: StmtExpr, times: int) -> Block:
+ if not isinstance(s, StmtExpr):
+ raise TypeError(f"{type(x)} not supported by Expand operation")
+
+ return Block([StmtExpr(Step(s.x, i)) for i in range(times)])
+
+def EvenTo(x: Var, n: int) -> Var:
+ return And(x, Negate(I(n-1)))
+
+def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Func, Call):
+ # TODO: More sophisticated computation for length of loop
+ if not isinstance(loop.cond, LE) and not isinstance(loop.cond, LT):
+ raise TypeError(f"{type(loop.cond)} not supported in loop unrolling")
+
+ # pull off needed features of the loop
+ it = loop.init.lhs.var
+ vars = set(VarsUsed(loop.body))
+
+ params = [v for v in vars if type(v) == Param]
+ stacks = [v for v in vars if type(v) == Var]
+
+ # TODO: More sophisticated type checking
+ n = loop.cond.r.var
+ if (type(n) != Param):
+ raise TypeError(f"{type(n)} not implemented yet")
+ params = [n] + params
+
+ kernel = Func(f"{name}{times}", Int, params, stacks)
+ body = loop.body
+
+ kernel.execute(
+ Set(n, EvenTo(n, times)),
+ For(Set(it, I(0)), LT(it, n), Repeat(loop.step, times),
+ body=Expand(loop.body, times)
+ ),
+ Return(n)
+ )
+
+ loop.init = None
+ return loop, kernel, Call(kernel, params)
# ------------------------------------------------------------------------
-# AST modification functions
+# Point of testing
+
+if __name__ == "__main__":
+ Rot = Func("blas·swap", Void,
+ params = Params(
+ (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y")
+ ),
+ vars = Vars(
+ (Int, "i"), (Float64, "tmp")
+ )
+ )
+ # Rot.declare( ... )
+
+ # TODO: Increase ergonomics here...
+ len, x, y, i, tmp = Rot.variables("len", "x", "y", "i", "tmp")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i), Inc(x), Inc(y)],
+ body = Swap(Deref(x), Deref(y), tmp)
+ )
+
+ rem, kernel, it = Unroll(loop, 8, "swap_kernel")
+ kernel.emit()
+ emitln(2)
+
+ Rot.execute(Set(i, it), rem)
+
+ Rot.emit()
+
+ print(buffer)