#! /bin/python3

from __future__ import annotations

import os
import sys
from enum import Enum
from typing import List, Tuple, Dict
from functools import singledispatch

# Alias for the builtin len; it keeps working even if the global name `len`
# is later rebound at module scope.
numel = len

# ------------------------------------------------------------------------
# String buffer

# Generated C text accumulates in `buffer`; `level` is the current
# indentation depth applied after every emitln().
level = 0
buffer = ""


def emit(s: str):
    """Append *s* verbatim to the global output buffer."""
    global buffer
    buffer += s


def emitln(n=1):
    """Append *n* newlines, then the indentation for the current level."""
    global buffer
    buffer += "\n" * n + (" " * level)


def enter_scope():
    """Emit ``{``, indent one level deeper and start a new line."""
    global level
    emit("{")
    level += 1
    emitln()


def exits_scope():
    """Dedent one level, start a new line and emit the closing ``}``."""
    global level
    level -= 1
    emitln()
    emit("}")


# Headers written at the top of each generated translation unit.
# <immintrin.h> declares the __m128*/__m256* vector types listed below.
DEFAULT_HEADERS = ("stdio.h", "stdint.h", "immintrin.h")


def emitheader(headers=DEFAULT_HEADERS):
    """Emit one ``#include <header>`` line per entry of *headers*, then a newline.

    A bare ``#include`` with no file name is not a valid C directive, so
    every include names its header explicitly.
    """
    for header in headers:
        emit(f"#include <{header}>\n")
    emitln()


# ------------------------------------------------------------------------
# Simple C AST
# TODO: Type checking


class Emitter:
    """Abstract base for every AST node.

    Subclasses override ``emit`` to append their C source text to the
    global buffer; the base implementation emits nothing.
    """

    def emit(self):
        pass


# ------------------------------------------
# Representation of a C type


class TypeKind(Enum):
    """Tags for the kinds of C types.

    For scalar and vector kinds the value is the type name used in
    generated code (vector kinds use the Intel intrinsic type names).

    NOTE: Enum members that share a value become aliases, so
    ``TypeKind.Int64x2 is TypeKind.Int32x4`` and
    ``TypeKind.Int64x4 is TypeKind.Int32x8``.
    """

    Void = "void"
    Error = "error"

    Int = "int"
    Int32 = "int32"
    Int64 = "int64"
    # vectorized variants
    Int32x4 = "__m128i"
    Int32x8 = "__m256i"
    Int64x2 = "__m128i"
    Int64x4 = "__m256i"

    Float32 = "float"
    Float64 = "double"
    # vectorized variants
    Float32x4 = "__m128"
    Float32x8 = "__m256"
    Float64x2 = "__m128d"
    Float64x4 = "__m256d"

    Pointer = "pointer"
    Struct = "struct"
    Enum = "enum"
    Union = "union"


class Type(Emitter):
    """A C type.

    ``emit`` writes the type specifier (e.g. ``double``); ``emitspec``
    writes the declarator for a named identifier (e.g. ``*x``, ``a[4]``).
    """

    def emit(self):
        pass

    def emitspec(self, var):
        pass


class Base(Type):
    """A named primitive type; its declarator is just the identifier."""

    def __init__(self, name: str):
        self.name = name

    def emit(self):
        emit(self.name)

    def emitspec(self, ident):
        emit(f"{ident}")


# Machine primitive types
# NOTE(review): "int8".."int64" are not standard C type names; generated
# code presumably relies on typedefs (see "TODO: Typedefs" below).
Void = Base("void")
Error = Base("error")
Int = Base("int")
Int8 = Base("int8")
Int16 = Base("int16")
Int32 = Base("int32")
Int64 = Base("int64")
Int32x4 = Base("__m128i")
Int32x8 = Base("__m256i")
Int64x2 = Base("__m128i")
Int64x4 = Base("__m256i")

Float32 = Base("float")
Float64 = Base("double")
Float32x4 = Base("__m128")
Float32x8 = Base("__m256")
Float64x2 = Base("__m128d")
Float64x4 = Base("__m256d")

# Width in bits of each primitive type.
# TODO: Make this ARCH dependent
BitDepth = {
    Void: 8,
    Int: 32,
    Int32: 32,
    Int64: 64,
    Float32: 32,
    Float64: 64,
}


class SIMD(Enum):
    """Instruction-set extensions a function may be vectorized for."""

    SSE = "sse"
    SSE2 = "sse2"
    AVX = "avx"
    AVX2 = "avx2"
    FMA3 = "fma3"
    AVX5 = "avx512"


# Vector register width, in bits, for each instruction set.
RegisterSize = {
    SIMD.SSE: 128,
    SIMD.SSE2: 128,
    SIMD.AVX: 256,
    SIMD.AVX2: 256,
    SIMD.FMA3: 256,
    SIMD.AVX5: 512,
}

# Element types treated as vectorizable under each instruction set.
SIMDSupport = {
    SIMD.SSE: {Float32},
    SIMD.SSE2: {Float32, Float64, Int, Int8, Int16, Int32, Int64},
    SIMD.AVX: {Float32, Float64, Int, Int8, Int16, Int32, Int64},
    SIMD.AVX2: {Float32, Float64, Int, Int8, Int16, Int32, Int64},
    SIMD.FMA3: {Float32, Float64, Int, Int8, Int16, Int32, Int64},
    SIMD.AVX5: {Float32, Float64, Int, Int8, Int16, Int32, Int64},
}

# TODO: Operation lookup tables...


class Ptr(Type):
    """Pointer to *to*.

    NOTE(review): ``emitspec`` writes ``*`` before the inner declarator, so
    a pointer to an Array would print ``*a[N]`` (an array of pointers in C);
    parentheses would be needed if that case is ever used.
    """

    def __init__(self, to: Type):
        self.to = to

    def emit(self):
        self.to.emit()

    def emitspec(self, ident):
        emit("*")
        self.to.emitspec(ident)


class Array(Type):
    """Fixed-length array of *len* elements of *base*."""

    def __init__(self, base: Type, len: int):
        self.base = base
        self.len = len

    def emit(self):
        self.base.emit()

    def emitspec(self, ident):
        self.base.emitspec(ident)
        emit(f"[{self.len}]")


# TODO: Typedefs...
# ------------------------------------------
# C expressions


class Expr(Emitter):
    """Base class for C expression nodes; the base ``emit`` writes nothing."""

    def emit(self):
        pass


# Literals


class Literal(Expr):
    """Mixin for literal nodes: emits ``str(self)``.

    Concrete literals also derive from a Python builtin (int/float/str), so
    they still behave like that value in Python arithmetic. Deriving from
    Expr lets literals appear wherever an expression is accepted (e.g. a
    ``For`` clause or ``Func.execute``).
    """

    def emit(self):
        emit(f"{self}")


class I(Literal, int):
    """Integer literal."""

    def __new__(cls, i: int):
        return super(I, cls).__new__(cls, i)


class F(Literal, float):
    """Floating-point literal, printed with Python's float formatting."""

    def __new__(cls, f: float):
        return super(F, cls).__new__(cls, f)


class S(Literal, str):
    """String literal. NOTE: emitted verbatim, without surrounding quotes."""

    def __new__(cls, s: str):
        return super(S, cls).__new__(cls, s)


# Ident of symbol
class Ident(Expr):
    """Reference to a declared variable; emits the variable's name."""

    def __init__(self, var):
        self.name = var.name
        self.var = var

    def emit(self):
        emit(f"{self.name}")


# Unary operators
class UnaryOp(Expr):
    """Expression with a single operand ``x``; the base emits nothing."""

    def __init__(self, x: Expr):
        self.x = x

    def emit(self):
        pass


class Deref(UnaryOp):
    """Pointer dereference ``*x``."""

    def emit(self):
        emit("*")
        self.x.emit()


class Negate(UnaryOp):
    """Bitwise complement ``~x`` (not arithmetic negation, despite the name)."""

    def emit(self):
        emit("~")
        self.x.emit()


class Ref(UnaryOp):
    """Address-of ``&x``."""

    def emit(self):
        emit("&")
        self.x.emit()


class _IncDec(UnaryOp):
    """Shared ``++``/``--`` emitter; *pre* selects prefix over postfix form."""

    op = ""

    def __init__(self, x: Expr, pre=False):
        super().__init__(x)
        self.pre = pre

    def emit(self):
        if self.pre:
            emit(self.op)
        self.x.emit()
        if not self.pre:
            emit(self.op)


class Inc(_IncDec):
    """Increment: ``++x`` when *pre* is true, otherwise ``x++``."""

    op = "++"


class Dec(_IncDec):
    """Decrement: ``--x`` when *pre* is true, otherwise ``x--``."""

    op = "--"


# Binary operators
class BinaryOp(Expr):
    """Infix expression ``l <op> r``.

    Subclasses set ``op`` to the C operator; the abstract base (``op`` is
    None) emits nothing.
    """

    # TODO: check types if they are vectorized and emit correct intrinsic
    op = None

    def __init__(self, left: Expr, right: Expr):
        self.l = left
        self.r = right

    def emit(self):
        if self.op is None:
            return
        self.l.emit()
        emit(f" {self.op} ")
        self.r.emit()


class Add(BinaryOp):
    op = "+"


class Sub(BinaryOp):
    op = "-"


class Mul(BinaryOp):
    op = "*"


class Div(BinaryOp):
    op = "/"


class And(BinaryOp):
    op = "&"


class Xor(BinaryOp):
    op = "^"


class GT(BinaryOp):
    op = ">"


class LT(BinaryOp):
    op = "<"


class GE(BinaryOp):
    op = ">="


class LE(BinaryOp):
    op = "<="


class EQ(BinaryOp):
    op = "=="


# Assignment (stores)
class Assign(Expr):
    """Assignment ``lhs <op> rhs``; subclasses set ``op`` (the base emits nothing)."""

    op = None

    def __init__(self, lhs: Expr, rhs: Expr):
        self.lhs = lhs
        self.rhs = rhs

    def emit(self):
        if self.op is None:
            return
        self.lhs.emit()
        emit(f" {self.op} ")
        self.rhs.emit()


class Set(Assign):
    op = "="


class AddSet(Assign):
    op = "+="


class SubSet(Assign):
    op = "-="


class MulSet(Assign):
    op = "*="


class DivSet(Assign):
    op = "/="


class Comma(Expr):
    """Comma operator ``x, next``; stored as the pair ``expr``."""

    def __init__(self, x: Expr, next: Expr):
        self.expr = (x, next)

    def emit(self):
        self.expr[0].emit()
        emit(", ")
        self.expr[1].emit()


class Index(Expr):
    """Subscript ``x[i]``."""

    def __init__(self, x: Expr, i: Expr):
        self.x = x
        self.i = i

    def emit(self):
        self.x.emit()
        emit("[")
        self.i.emit()
        emit("]")


class Paren(Expr):
    """Parenthesized expression ``(x)``."""

    def __init__(self, x: Expr):
        self.x = x

    def emit(self):
        emit("(")
        self.x.emit()
        emit(")")


class Call(Expr):
    """Call ``func(arg0, arg1, ...)``; each argument emits itself."""

    def __init__(self, func: Func, args: List[Param]):
        self.func = func
        self.args = args

    def emit(self):
        emit(self.func.name)
        emit("(")
        for k, arg in enumerate(self.args):
            if k:
                emit(", ")
            arg.emit()
        emit(")")


def ExprList(*expr: Expr) -> Expr:
    """Fold expressions into a right-nested Comma chain.

    ``ExprList(a, b, c)`` builds ``Comma(a, Comma(b, c))``; a single
    expression is returned unchanged.

    Raises:
        ValueError: if no expression is given.
    """
    if not expr:
        raise ValueError("ExprList requires at least one expression")
    head, *rest = expr
    return Comma(head, ExprList(*rest)) if rest else head


# Common assignments are kept globally

# C statements
class Stmt(Emitter):
    """Base class for C statements; the base emits the terminating ``;``."""

    def emit(self):
        emit(";")


class Empty(Stmt):
    """The empty statement ``;``."""

    def __init__(self):
        pass

    def emit(self):
        super().emit()


class Block(Stmt):
    """Braced compound statement; statements are separated by newlines."""

    def __init__(self, stmts: List[Stmt]):
        self.stmts = stmts

    def emit(self):
        enter_scope()
        for k, stmt in enumerate(self.stmts):
            if k:
                emitln()
            stmt.emit()
        exits_scope()


def _as_expr(clause):
    """Normalize a For clause: None/Expr unchanged, sequences comma-joined, empty -> None."""
    if clause is None or isinstance(clause, Expr):
        return clause
    parts = list(clause)
    return ExprList(*parts) if parts else None


class For(Stmt):
    """C ``for`` loop.

    Each of *init*, *cond* and *step* may be an Expr, a list of Exprs
    (joined with the comma operator), or None/empty to leave the clause
    blank. *body* defaults to the empty statement and is wrapped in braces
    when it is not already a Block.
    """

    def __init__(self, init: Expr | List[Expr] | None, cond: Expr | List[Expr] | None,
                 step: Expr | List[Expr] | None, body: Stmt | None = None):
        self.init = _as_expr(init)
        self.cond = _as_expr(cond)
        self.step = _as_expr(step)
        self.body = body if body is not None else Empty()

    def emit(self):
        emit("for (")
        if self.init is not None:
            self.init.emit()
        emit("; ")
        if self.cond is not None:
            self.cond.emit()
        emit("; ")
        if self.step is not None:
            self.step.emit()
        emit(") ")
        if isinstance(self.body, Block):
            self.body.emit()
        else:
            enter_scope()
            self.body.emit()
            exits_scope()


class Return(Stmt):
    """``return val;``, started on a new line."""

    def __init__(self, val: Expr):
        self.val = val

    def emit(self):
        emitln()
        emit("return ")
        self.val.emit()
        super().emit()


class StmtExpr(Stmt):
    """Expression statement ``x;``."""

    def __init__(self, x: Expr):
        self.x = x

    def emit(self):
        self.x.emit()
        super().emit()


class Mem(Enum):
    """Storage-class specifier; ``Auto`` emits nothing."""

    Auto = ""
    Static = "static"
    Register = "register"
    Typedef = "typedef"
    External = "extern"


class Decl(Emitter):
    """Base class for declarations."""

    def __init__(self):
        pass

    def emit(self):
        pass


def _flatten_vars(vars):
    """Flatten one level of grouped declarations ([a, [b, c]] -> [a, b, c])."""
    flat = []
    for v in vars:
        if isinstance(v, list):
            flat.extend(v)
        else:
            flat.append(v)
    return flat


class Func(Decl):
    """C function definition.

    Args:
        name: emitted function name.
        ret: return type (a Type such as ``Int`` or ``Void``).
        params: Params emitted in the signature.
        vars: local declarations; a nested list emits several declarators
            sharing the first one's type on one line (``double a, b;``).
        body: statements emitted in order inside the function body.
    """

    def __init__(self, name: str, ret: Type = Void, params: List[Param] = None,
                 vars: List[Var | List[Var]] = None, body: List[Stmt] = None):
        self.name = name
        self.ret = ret
        self.params = params if params is not None else []
        self.vars = vars if vars is not None else []
        self.stmts = body if body is not None else []

    def emit(self):
        """Emit the return type line, the signature, then a braced body.

        The body holds the local declarations, then a blank line if there
        were any, then the statements separated by newlines.
        """
        self.ret.emit()
        emitln()
        emit(self.name)
        emit("(")
        for k, p in enumerate(self.params):
            if k:
                emit(", ")
            p.emittype()
            emit(" ")
            p.emitspec()
        emit(")\n")
        enter_scope()

        for var in self.vars:
            group = var if isinstance(var, list) else [var]
            if not group:
                continue
            first, *others = group
            first.emittype()
            emit(" ")
            first.emitspec()
            for v in others:
                emit(", ")
                v.emitspec()
            emit(";")
            emitln()
        if self.vars:
            emitln()

        for k, stmt in enumerate(self.stmts):
            if k:
                emitln()
            stmt.emit()
        exits_scope()

    def declare(self, var: Var | List[Var], *vars: Var | List[Var]) -> Expr | List[Expr]:
        """Add local declarations and return Idents referring to them.

        A single plain Var returns one Ident; otherwise a list of Idents
        (grouped declarations are flattened) is returned.
        """
        self.vars.append(var)
        self.vars.extend(vars)
        idents = [Ident(v) for v in _flatten_vars([var, *vars])]
        if not vars and not isinstance(var, list):
            return idents[0]
        return idents

    def execute(self, stmt: Stmt | Expr, *args: Stmt | Expr):
        """Append statements to the body; bare Exprs become expression statements.

        Raises:
            TypeError: if an argument is neither a Stmt nor an Expr.
        """
        for node in (stmt, *args):
            if isinstance(node, Stmt):
                self.stmts.append(node)
            elif isinstance(node, Expr):
                self.stmts.append(StmtExpr(node))
            else:
                raise TypeError("unrecognized type for function")

    def variables(self, *idents: str) -> List[Expr]:
        """Return Idents for the named locals/params (params win on name clashes).

        Raises:
            KeyError: if a name is not declared.
        """
        known = {v.name: v for v in _flatten_vars(self.vars) + self.params}
        return [Ident(known[ident]) for ident in idents]


class Var(Decl):
    """A variable: C type, name and storage class."""

    def __init__(self, type: Type, name: str, storage: Mem = Mem.Auto):
        self.name = name
        self.type = type
        self.storage = storage

    def emit(self):
        emit(f"{self.name}")

    def emittype(self):
        """Emit the storage class (unless Auto) and the type specifier."""
        if self.storage != Mem.Auto:
            emit(self.storage.value)
            emit(" ")
        self.type.emit()

    def emitspec(self):
        """Emit this variable's declarator (name plus pointer/array syntax)."""
        self.type.emitspec(self.name)


class Param(Var):
    """Function parameter (always automatic storage)."""

    def __init__(self, type: Type, name: str):
        super().__init__(type, name, Mem.Auto)


def Params(*ps: Tuple[Type, str]) -> List[Param]:
    """Build Params from ``(type, name)`` pairs."""
    return [Param(p[0], p[1]) for p in ps]


def Vars(*ps: Tuple[Type, str]) -> List[Var]:
    """Build Vars from ``(type, name)`` pairs."""
    return [Var(p[0], p[1]) for p in ps]


# ------------------------------------------------------------------------
# AST modification/production functions


def Swap(x: Expr, y: Expr, tmp: Expr) -> Stmt:
    """Return ``tmp = x, x = y, y = tmp;`` as one expression statement."""
    return StmtExpr(ExprList(Set(tmp, x), Set(x, y), Set(y, tmp)))


def IsLoop(s: Stmt) -> bool:
    """True if *s* is a For loop."""
    return isinstance(s, For)


# TODO: Generalize to memory addresses!
# This will allow better vectorization
@singledispatch
def VarsUsed(s: object) -> List[Var]:
    """Collect the Vars referenced anywhere inside AST node *s*.

    Results are in traversal order and may contain duplicates.

    Raises:
        TypeError: if no handler is registered for ``type(s)``.
    """
    raise TypeError(f"{type(s)} not supported by VarsUsed operation")


@VarsUsed.register
def _(op: UnaryOp):
    return VarsUsed(op.x)


@VarsUsed.register
def _(sym: Ident):
    return [sym.var]


@VarsUsed.register
def _(var: Var):
    # A bare Var can stand in for an Ident (Var.emit prints its name).
    return [var]


@VarsUsed.register
def _(s: Assign):
    return [*VarsUsed(s.lhs), *VarsUsed(s.rhs)]


@VarsUsed.register
def _(lit: Literal):
    return []


@VarsUsed.register
def _(ix: Index):
    return [*VarsUsed(ix.x), *VarsUsed(ix.i)]


@VarsUsed.register
def _(comma: Comma):
    return [*VarsUsed(comma.expr[0]), *VarsUsed(comma.expr[1])]


@VarsUsed.register
def _(op: BinaryOp):
    return [*VarsUsed(op.l), *VarsUsed(op.r)]


@VarsUsed.register
def _(p: Paren):
    return VarsUsed(p.x)


@VarsUsed.register
def _(call: Call):
    return [v for arg in call.args for v in VarsUsed(arg)]


@VarsUsed.register
def _(s: Empty):
    return []


@VarsUsed.register
def _(blk: Block):
    return [v for stmt in blk.stmts for v in VarsUsed(stmt)]


@VarsUsed.register
def _(loop: For):
    # Blank clauses are stored as None (e.g. a remainder loop after Unroll).
    parts = (loop.init, loop.cond, loop.step, loop.body)
    return [v for part in parts if part is not None for v in VarsUsed(part)]


@VarsUsed.register
def _(ret: Return):
    return VarsUsed(ret.val)


@VarsUsed.register
def _(x: StmtExpr):
    return VarsUsed(x.x)


@singledispatch
def Repeat(x: Expr, times: int) -> Expr:
    """Return a step expression equivalent to running *x* *times* times.

    ``i++`` becomes ``i += times`` and ``i--`` becomes ``i -= times``;
    comma lists are rewritten element-wise.

    Raises:
        TypeError: for unsupported expression types.
    """
    raise TypeError(f"{type(x)} not supported by Repeat operation")


@Repeat.register
def _(x: Inc, times: int):
    return AddSet(x.x, I(times))


@Repeat.register
def _(x: Dec, times: int):
    return SubSet(x.x, I(times))


@Repeat.register
def _(x: Comma, times: int):
    return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times))


# TODO: Parameterize the variables that should not be expanded
@singledispatch
def Step(x: Expr, i: int) -> Expr:
    """Rewrite *x* as it reads in unrolled copy *i* of a loop body.

    Memory accesses are offset by *i* (``*p`` -> ``p[i]``, ``p[k]`` ->
    ``p[k + i]`` for literal ``k``); identifiers and literals are left
    unchanged, and compound expressions are rewritten operand by operand.

    Raises:
        TypeError: for unsupported expression types.
    """
    raise TypeError(f"{type(x)} not supported by Step operation")


@Step.register
def _(x: Comma, i: int):
    return Comma(Step(x.expr[0], i), Step(x.expr[1], i))


@Step.register
def _(x: Assign, i: int):
    return type(x)(Step(x.lhs, i), Step(x.rhs, i))


@Step.register
def _(x: BinaryOp, i: int):
    return type(x)(Step(x.l, i), Step(x.r, i))


@Step.register
def _(x: Paren, i: int):
    return Paren(Step(x.x, i))


@Step.register
def _(x: Ident, i: int):
    return x


@Step.register
def _(x: Literal, i: int):
    return x


@Step.register
def _(x: Deref, i: int):
    return Index(x.x, I(i))


@Step.register
def _(ix: Index, i: int):
    # Only literal integer subscripts can be offset at generation time.
    if not isinstance(ix.i, int):
        raise TypeError(f"{type(ix.i)} index not supported by Step operation")
    return Index(ix.x, I(ix.i + i))


def Expand(s: StmtExpr, times: int) -> Block:
    """Return a Block holding *times* copies of *s*, copy k rewritten by Step(., k).

    Raises:
        TypeError: if *s* is not a StmtExpr.
    """
    if not isinstance(s, StmtExpr):
        raise TypeError(f"{type(s)} not supported by Expand operation")
    return Block([StmtExpr(Step(s.x, k)) for k in range(times)])


def EvenTo(x: Expr, n: int) -> Expr:
    """Return the expression ``x & ~(n-1)``: *x* rounded down to a multiple of *n*.

    The mask trick is only correct when *n* is a power of two.

    Raises:
        ValueError: if *n* is not a positive power of two.
    """
    if n <= 0 or n & (n - 1):
        raise ValueError(f"EvenTo requires a positive power of two, got {n}")
    return And(x, Negate(I(n - 1)))


def Unroll(loop: For, times: int, name: str,
           vars: List[Var] | None = None) -> Tuple[For, Func, Call]:
    """Split *loop* into an unrolled kernel function plus a remainder loop.

    The kernel ``<name>_kernel<times>`` takes the loop bound followed by the
    parameters the body uses. It rounds the bound down to a multiple of
    *times*, runs *times* expanded copies of the body per iteration and
    returns the rounded bound. *loop* is modified in place: its initializer
    is cleared so it can serve as the remainder loop.

    Args:
        loop: For whose init is ``ident = ...`` and whose condition is a
            ``<`` or ``<=`` comparison with a Param bound on the right.
        times: unroll factor; must be a positive power of two (see EvenTo).
        name: prefix of the generated kernel's name.
        vars: currently unused; kept for interface compatibility.

    Returns:
        ``(remainder_loop, kernel, call)``; *call* invokes the kernel.

    Raises:
        TypeError: if the loop shape is not supported.
        ValueError: if *times* is not a positive power of two.

    NOTE(review): pointer parameters are passed by value, so pointers the
    kernel steps are not advanced for the caller; the caller must offset
    them before running the remainder loop.
    """
    # TODO: More sophisticated computation for length of loop
    if not isinstance(loop.cond, (LE, LT)):
        raise TypeError(f"{type(loop.cond)} not supported in loop unrolling")
    if not isinstance(loop.init, Assign) or not isinstance(loop.init.lhs, Ident):
        raise TypeError(f"{type(loop.init)} not supported as loop initializer in loop unrolling")

    # pull off needed features of the loop
    it = loop.init.lhs.var
    # dict.fromkeys dedupes while keeping first-use order, so the generated
    # signature is deterministic (a set of objects would order by id()).
    used = list(dict.fromkeys(VarsUsed(loop.body)))
    params = [v for v in used if isinstance(v, Param)]
    stacks = [v for v in used if not isinstance(v, Param)]

    # TODO: More sophisticated type checking
    bound = loop.cond.r
    if not isinstance(bound, Ident):
        raise TypeError(f"{type(bound)} not implemented yet")
    n = bound.var
    if not isinstance(n, Param):
        raise TypeError(f"{type(n)} not implemented yet")
    params = [n] + [p for p in params if p is not n]

    # The induction variable must exist in the kernel even when the body
    # never reads it.
    if isinstance(it, Param):
        if it not in params:
            params.append(it)
    elif it not in stacks:
        stacks.append(it)

    kernel = Func(f"{name}_kernel{times}", Int, params, stacks)
    count, index = Ident(n), Ident(it)
    kernel.execute(
        Set(count, EvenTo(count, times)),
        For(Set(index, I(0)),
            LT(index, count),
            Repeat(loop.step, times),
            body=Expand(loop.body, times)),
        Return(count),
    )

    loop.init = None
    return loop, kernel, Call(kernel, list(params))


# Replaces all vectorizable loops inside Func with vectorized variants
# Returns the new vectorized function (tagged by the SIMD chosen)
def Vectorize(func: Func, arrays: List[Var], isa: SIMD) -> Func:
    """Return a copy of *func* named ``<name>_<isa>``.

    No SIMD code is generated yet: loops are shallow-copied with their
    bodies intact and other statements are shared. *arrays* is currently
    unused.

    Raises:
        ValueError: for any ISA other than AVX2.
    """
    import copy

    if isa != SIMD.AVX2:
        raise ValueError(f"ISA '{isa}' not currently implemented")

    # Copy the lists so later declare()/execute() on either function does
    # not mutate the other.
    vfunc = Func(f"{func.name}_{isa.value}", func.ret, list(func.params), list(func.vars))
    for stmt in func.stmts:
        vfunc.execute(copy.copy(stmt) if IsLoop(stmt) else stmt)
    return vfunc


# ------------------------------------------------------------------------
# Point of testing

if __name__ == "__main__":
    emitheader()

    Rot = Func("blas·swap", Void,
               Params((Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y")),
               Vars((Int, "i"), (Float64, "tmp")))
    # could also declare like
    #   Rot.declare( ... )
    # TODO: Increase ergonomics here...
    # `length`, not `len`: binding `len` here would shadow the builtin
    # for the whole module.
    length, x, y, i, tmp = Rot.variables("len", "x", "y", "i", "tmp")
    loop = For(Set(i, I(0)), LT(i, length), [Inc(i), Inc(x), Inc(y)],
               body=Swap(Index(x, I(0)), Index(y, I(0)), tmp))

    rem, kernel, call = Unroll(loop, 8, "swap")
    kernel.emit()
    emitln(2)

    avx256kernel = Vectorize(kernel, [], SIMD.AVX2)
    avx256kernel.emit()
    emitln(2)

    # The kernel stepped its own copies of x and y; advance the caller's
    # pointers past the i elements it handled before the remainder loop.
    Rot.execute(Set(i, call), AddSet(x, i), AddSet(y, i), rem)
    Rot.emit()
    print(buffer)