#! /bin/python3
from __future__ import annotations

import os
import sys
import sympy

from enum import Enum
from typing import Callable, List, Tuple, Dict
from functools import singledispatch

# Short alias used throughout the generator for "number of elements".
numel = len

# ------------------------------------------------------------------------
# String buffer
#
# All generated C text accumulates in the module-level string `buffer`.
# `level` is the current brace-nesting depth; `emitln` pads every new line
# with `level` copies of the indent unit below.
# NOTE(review): the indent unit reads as a single space in this copy of the
# source; confirm it was not a tab or several spaces originally.

level = 0
buffer = ""


def emit(s: str):
    """Append the text *s* to the output buffer unchanged."""
    global buffer
    buffer = buffer + s


def emitln(n=1):
    """Append *n* newlines followed by indentation for the current depth."""
    global buffer
    indent = " " * level
    buffer = buffer + ("\n" * n + indent)


def enter_scope():
    """Open a C brace scope: write "{", deepen nesting, start a new line."""
    global level
    emit("{")
    level = level + 1
    emitln()


def exits_scope():
    """Close the innermost C brace scope on its own line."""
    global level
    level = level - 1
    emitln()
    emit("}")


def emitheader():
    """Write the #include preamble followed by a line break."""
    # NOTE(review): these directives carry no header name in this copy of the
    # source (the <...> part appears to have been stripped); confirm which
    # headers were intended.
    for directive in ("#include \n", "#include \n", "#include \n"):
        emit(directive)
    emitln()


# ------------------------------------------------------------------------
# Simple C AST
# TODO: Type checking


class Emitter(object):
    """Root of every AST node.

    Each node implements ``emit()``, which appends its formatted C text to
    the global output buffer.
    """

    def emit(self):
        pass


# ------------------------------------------
# Representation of a C type


class Type(Emitter):
    """Abstract C type.

    ``emit`` writes the type specifier. ``emitspec`` writes the declarator
    (for example the identifier, any pointer stars and any array suffix).
    """

    def emit(self):
        pass

    def emitspec(self, var):
        pass


class Base(Type):
    """A named builtin C type, identified solely by its spelling."""

    def __init__(self, name: str):
        self.name = name

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        if type(other) is not type(self):
            return False
        return other.name == self.name

    def __str__(self):
        return self.name

    def emit(self):
        emit(self.name)

    def emitspec(self, ident):
        # A plain type's declarator is just the identifier itself.
        emit(f"{ident}")

# TODO: Operation lookup tables...
class Ptr(Type):
    """Pointer to another C type (``to``)."""

    def __init__(self, to: Type):
        self.to = to

    def __str__(self):
        return f"*{self.to}"

    def __eq__(self, other):
        return type(self) == type(other) and self.to == other.to

    def __hash__(self):
        return hash((hash(self.to), "*"))

    def emit(self):
        # Only the pointee's base type is written here; the "*" belongs to
        # the declarator and is produced by emitspec.
        self.to.emit()

    def emitspec(self, ident):
        emit("*")
        self.to.emitspec(ident)


class Array(Type):
    """Fixed-size C array holding ``len`` elements of type ``base``."""

    def __init__(self, base: Type, len: int):
        self.base = base
        self.len = len

    def __str__(self):
        return f"[{self.len}]{self.base}"

    def __eq__(self, other):
        # The length participates in __hash__, so it must participate in
        # equality too; otherwise two "equal" arrays could hash differently.
        return (
            type(self) == type(other)
            and self.base == other.base
            and self.len == other.len
        )

    def __hash__(self):
        return hash((hash(self.base), self.len))

    def emit(self):
        self.base.emit()

    def emitspec(self, ident):
        self.base.emitspec(ident)
        emit(f"[{self.len}]")


# Machine primitive types
Void = Base("void")
Error = Base("error")
Byte = Base("byte")
String = Ptr(Byte)

Int = Base("int")
Int8 = Base("int8")
Int16 = Base("int16")
Int32 = Base("int32")
Int64 = Base("int64")

# x86 SIMD register types. The spellings must match the intrinsic types
# (__m128i, __m256i, __m128, __m256, __m128d, __m256d) because they are
# emitted verbatim into the generated C.
Int32x4 = Base("__m128i")
Int32x8 = Base("__m256i")
Int64x2 = Base("__m128i")
Int64x4 = Base("__m256i")

Float32 = Base("float")
Float64 = Base("double")
Float32x4 = Base("__m128")
Float32x8 = Base("__m256")
Float64x2 = Base("__m128d")
Float64x4 = Base("__m256d")

_VECTOR_TYPES = (
    Float32x4, Float32x8, Float64x2, Float64x4,
    Int32x4, Int32x8, Int64x2, Int64x4,
)


def IsVectorType(kind):
    """Return True if *kind* is one of the SIMD register types (or a pointer to one).

    Matching is by object identity against the module-level constants above.
    """
    if any(kind is t for t in _VECTOR_TYPES):
        return True
    if type(kind) == Ptr:
        return IsVectorType(kind.to)
    return False


def IsArrayType(kind):
    """Return True if *kind* is a fixed-size Array type."""
    return isinstance(kind, Array)


# TODO: Make this ARCH dependent
BitDepth = {
    Void: 8,
    Int: 32,
    Int32: 32,
    Int64: 64,
    Float32: 32,   # a C float is 32 bits wide
    Float64: 64,
}


class SIMD(Enum):
    SSE = "sse"
    SSE2 = "sse2"
    AVX = "avx"
    AVX2 = "avx2"
    FMA3 = "fma3"
    AVX5 = "avx512"


RegisterSize = {
    SIMD.SSE: 128,
    SIMD.SSE2: 128,
    SIMD.AVX: 256,
    SIMD.AVX2: 256,
    SIMD.FMA3: 256,
    SIMD.AVX5: 512,
}

SIMDSupport = {
    SIMD.SSE: set([Float32]),
    SIMD.SSE2: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
    SIMD.AVX: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
    SIMD.AVX2: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
    SIMD.FMA3: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
    SIMD.AVX5: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
}


# TODO: Think of a better way to handle this
# Emit helpers for (lhs type, rhs type) pairs that need an intrinsic instead
# of a plain C assignment; see Set.method and Deref.method below.
def emit·load128(l, r):
    if l is not None:
        l.emit()
        emit(" = ")
    emit("_mm_loadu_sd(")
    r.emit()
    emit(")")


def emit·broadcast256(l, r):
    l.emit()
    emit(" = _mm256_broadcastsd_pd(")
    r.emit()
    emit(")")


def emit·load256(l, r):
    if l is not None:
        l.emit()
        emit(" = ")
    emit("_mm256_loadu_pd(")
    r.emit()
    emit(")")


def emit·store256(l, r):
    emit("_mm256_storeu_pd(")
    l.emit()
    emit(", ")
    r.emit()
    emit(")")


def emit·copy256(l, r):
    emit("_mm256_storeu_pd(")
    l.emit()
    emit(", ")
    emit("_mm256_loadu_pd(")
    r.emit()
    emit("))")

# TODO: Typedefs...

# ------------------------------------------
# C expressions


class Expr(Emitter):
    """Base class of every C expression node."""

    def emit(self):
        pass


# Literals
class Literal(Expr):
    """A literal value; emits its Python string form."""

    def emit(self):
        emit(f"{self}")


class I(Literal, int):
    """Integer literal."""

    def __new__(cls, i: int):
        return super(I, cls).__new__(cls, i)


class F(Literal, float):
    """Floating-point literal."""

    def __new__(cls, f: float):
        return super(F, cls).__new__(cls, f)


class S(Literal, str):
    """String literal (emitted verbatim, without quotes)."""

    def __new__(cls, s: str):
        return super(S, cls).__new__(cls, s)


# Ident of symbol
class Ident(Expr):
    """Reference to a declared variable; equality and hashing use the name only."""

    def __init__(self, var):
        self.name = var.name
        self.var = var

    def __str__(self):
        return str(self.name)

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return type(self) == type(other) and self.name == other.name

    def emit(self):
        emit(f"{self.name}")


# Unary operators
class UnaryOp(Expr):
    """Base of single-operand expressions; the operand is ``x``."""

    def __init__(self, x: Expr):
        self.x = x

    def emit(self):
        pass


class Deref(UnaryOp):
    """``*x``, or a vector load intrinsic when x points to a SIMD register."""

    method = {
        Ptr(Float64x4): lambda x: emit·load256(None, x),
    }

    def emit(self):
        kind = GetType(self.x)
        if kind in self.method:
            self.method[kind](self.x)
        else:
            emit("*")
            self.x.emit()


class Negate(UnaryOp):
    """Bitwise complement ``~x``."""

    def emit(self):
        emit("~")
        self.x.emit()


class Ref(UnaryOp):
    """Address-of ``&x``."""

    def emit(self):
        emit("&")
        self.x.emit()


class Inc(UnaryOp):
    """``x++`` or, when ``pre`` is true, ``++x``."""

    def __init__(self, x: Expr, pre=False):
        self.x = x
        self.pre = pre

    def emit(self):
        if self.pre:
            emit("++")
            self.x.emit()
        else:
            self.x.emit()
            emit("++")


class Dec(UnaryOp):
    """``x--`` or, when ``pre`` is true, ``--x``."""

    def __init__(self, x: Expr, pre=False):
        self.x = x
        self.pre = pre

    def emit(self):
        if self.pre:
            emit("--")
            self.x.emit()
        else:
            self.x.emit()
            emit("--")


# Binary operators
class BinaryOp(Expr):
    """Two-operand expression ``l op r``.

    Subclasses that are plain infix operators only set ``op``. The base
    class (``op is None``) emits nothing.
    """

    op = None

    def __init__(self, left: Expr, right: Expr):
        self.l = left
        self.r = right

    def emit(self):
        if self.op is None:
            return
        self.l.emit()
        emit(f" {self.op} ")
        self.r.emit()


# TODO: check types if they are vectorized and emit correct intrinsic
class Add(BinaryOp):
    op = "+"


class Sub(BinaryOp):
    op = "-"


class Mul(BinaryOp):
    op = "*"


class Div(BinaryOp):
    op = "/"


class And(BinaryOp):
    op = "&"


class Xor(BinaryOp):
    op = "^"


class GT(BinaryOp):
    op = ">"


class LT(BinaryOp):
    op = "<"


class GE(BinaryOp):
    op = ">="


class LE(BinaryOp):
    op = "<="


class EQ(BinaryOp):
    op = "=="


# Loads a scalar
class SIMDLoad(BinaryOp):
    def emit(self):
        self.l.emit()
        self.r.emit()


# Loads data at address
class SIMDLoadAt(BinaryOp):
    def emit(self):
        self.l.emit()
        emit(" = _mm256_loadu_pd(")
        self.r.emit()
        emit(")")


# Assignment (stores)
class Assign(Expr):
    """Base of assignment expressions ``lhs op rhs``; ``op`` set by subclasses."""

    op = None

    def __init__(self, lhs: Expr, rhs: Expr):
        self.lhs = lhs
        self.rhs = rhs

    def emit(self):
        if self.op is None:
            return
        self.lhs.emit()
        emit(f" {self.op} ")
        self.rhs.emit()


class Set(Assign):
    """Plain assignment. SIMD type pairs dispatch to load/store intrinsics."""

    op = "="
    method = {
        (Float64x2, Ptr(Float64)): emit·load128,
        (Float64x4, Float64x2): emit·broadcast256,
        (Float64x4, Ptr(Float64x4)): emit·load256,
        (Ptr(Float64x4), Float64x4): emit·store256,
        (Ptr(Float64x4), Ptr(Float64x4)): emit·copy256,
    }

    def emit(self):
        key = (GetType(self.lhs), GetType(self.rhs))
        if key in self.method:
            self.method[key](self.lhs, self.rhs)
        else:
            super().emit()


class AddSet(Assign):
    op = "+="


class SubSet(Assign):
    op = "-="


class MulSet(Assign):
    op = "*="


class DivSet(Assign):
    op = "/="


class Comma(Expr):
    """C comma expression ``x, next``."""

    def __init__(self, x: Expr, next: Expr):
        self.expr = (x, next)

    def emit(self):
        self.expr[0].emit()
        emit(", ")
        self.expr[1].emit()


class Index(Expr):
    """Subscript ``x[i]``."""

    def __init__(self, x: Expr, i: Expr):
        self.x = x
        self.i = i

    def emit(self):
        self.x.emit()
        emit("[")
        self.i.emit()
        emit("]")


class Paren(Expr):
    """Parenthesised expression ``(x)``."""

    def __init__(self, x: Expr):
        self.x = x

    def emit(self):
        emit("(")
        self.x.emit()
        emit(")")


class Call(Expr):
    """Function call ``func(args...)``; ``func`` is a Func (its name is emitted)."""

    def __init__(self, func: Func, args: List[Param]):
        self.func = func
        self.args = args

    def emit(self):
        emit(self.func.name)
        emit("(")
        if numel(self.args) > 0:
            self.args[0].emit()
            for arg in self.args[1:]:
                emit(", ")
                arg.emit()
        emit(")")


def ExprList(*expr: Expr):
    """Fold one or more expressions into a right-nested Comma chain."""
    if numel(expr) > 1:
        return Comma(expr[0], ExprList(*expr[1:]))
    return expr[0]

# Common assignments are kept globally


# C statements
class Stmt(Emitter):
    """Base statement; on its own it emits just the terminating ";"."""

    def emit(self):
        emit(";")


class Empty(Stmt):
    pass


class Block(Stmt):
    """Brace-enclosed statement list; bare expressions are wrapped in StmtExpr."""

    def __init__(self, *stmts: Stmt | Expr):
        self.stmts = [s if isinstance(s, Stmt) else StmtExpr(s) for s in stmts]

    def emit(self):
        enter_scope()
        for i, stmt in enumerate(self.stmts):
            stmt.emit()
            if i < numel(self.stmts) - 1:
                emitln()
        exits_scope()


def _AsExpr(x):
    """Normalise an optional expression, or a list of them, into one Expr (or None)."""
    if x is None or isinstance(x, Expr):
        return x
    return ExprList(*x)


def _AsStmt(x):
    """Normalise a statement argument.

    A list becomes a Block of its items and a bare Expr becomes a StmtExpr,
    so it is emitted with its ";". Anything else, including None, is
    returned unchanged.
    """
    if isinstance(x, list):
        return Block(*x)
    if isinstance(x, Expr):
        return StmtExpr(x)
    return x


class If(Stmt):
    """``if (cond) then [else orelse]``."""

    def __init__(self, cond: Expr | List[Expr], then: Stmt | List[Stmt],
                 orelse: Stmt | List[Stmt] | None = None):
        self.cond = _AsExpr(cond)
        self.then = _AsStmt(then)
        self.orelse = _AsStmt(orelse)

    def emit(self):
        emit("if (")
        self.cond.emit()
        emit(") ")
        self.then.emit()
        if self.orelse is not None:
            emit(" else ")
            self.orelse.emit()


class For(Stmt):
    """``for (init; cond; step) body``.

    Any header clause may be None (emitted empty). A missing body is Empty.
    """

    def __init__(self, init: Expr | List[Expr] | None, cond: Expr | List[Expr] | None,
                 step: Expr | List[Expr] | None, body: Stmt | Expr | None = None):
        self.init = _AsExpr(init)
        self.cond = _AsExpr(cond)
        self.step = _AsExpr(step)
        self.body = Empty() if body is None else _AsStmt(body)

    def emit(self):
        emit("for (")
        if self.init is not None:
            self.init.emit()
        emit("; ")
        if self.cond is not None:
            self.cond.emit()
        emit("; ")
        if self.step is not None:
            self.step.emit()
        emit(") ")
        if isinstance(self.body, Block):
            self.body.emit()
        else:
            enter_scope()
            self.body.emit()
            exits_scope()

    def execute(self, stmt: Stmt | Expr):
        """Append *stmt* to the loop body, promoting the body to a Block if needed."""
        if isinstance(stmt, Expr):
            stmt = StmtExpr(stmt)
        if isinstance(self.body, Empty):
            self.body = stmt
            return
        elif not isinstance(self.body, Block):
            self.body = Block(self.body)
        self.body.stmts.append(stmt)


class Return(Stmt):
    """``return val;`` (preceded by a line break)."""

    def __init__(self, val: Expr):
        self.val = val

    def emit(self):
        emitln()
        emit("return ")
        self.val.emit()
        super().emit()


class StmtExpr(Stmt):
    """An expression used as a statement (``x;``)."""

    def __init__(self, x: Expr):
        self.x = x

    def emit(self):
        self.x.emit()
        super().emit()


class Mem(Enum):
    """Storage-class specifier; Auto emits nothing."""
    Auto = ""
    Static = "static"
    Register = "register"
    Typedef = "typedef"
    External = "extern"


class Decl(Emitter):
    def __init__(self):
        pass

    def emit(self):
        pass


class Func(Decl):
    """A C function definition: signature, local declarations and body.

    ``vars`` entries may be single Vars or lists of Vars; a list is emitted
    as one comma-separated declaration sharing the first Var's type.
    """

    def __init__(self, name: str, ret: Type = Void, params: List[Param] = None,
                 vars: List[Var | List[Var]] = None, body: List[Stmt] = None):
        self.name = name
        self.ret = ret
        self.params = params if params is not None else []
        self.vars = vars if vars is not None else []
        self.stmts = body if body is not None else []

    def _flatvars(self):
        """Yield every declared local, flattening grouped declarations."""
        for v in self.vars:
            if isinstance(v, list):
                yield from v
            else:
                yield v

    def emit(self):
        self.ret.emit()
        emitln()
        emit(self.name)
        emit("(")
        for i, p in enumerate(self.params):
            p.emittype()
            emit(" ")
            p.emitspec()
            if i < numel(self.params) - 1:
                emit(", ")
        emit(")\n")

        enter_scope()
        for var in self.vars:
            if isinstance(var, list):
                first = var[0]
                first.emittype()
                emit(" ")
                first.emitspec()
                for v in var[1:]:
                    emit(", ")
                    v.emitspec()
            else:
                var.emittype()
                emit(" ")
                var.emitspec()
            emit(";")
            emitln()
        if numel(self.vars) > 0:
            emitln()

        for stmt in self.stmts[:-1]:
            stmt.emit()
            emitln()
        if numel(self.stmts) > 0:
            self.stmts[-1].emit()
        exits_scope()
        emitln(2)

    def declare(self, var: Var, *vars: Var) -> Expr | List[Expr]:
        """Declare one or more locals and return Ident(s) referring to them.

        A name that is already declared is not added again; the Ident then
        refers to the existing declaration. Previously None was returned in
        that case and any extra vars were dropped.

        Returns:
            A single Ident when only *var* is given, otherwise a list of Idents
            in argument order.
        """
        declared = {v.name: v for v in self._flatvars()}
        idents = []
        for v in (var, *vars):
            if v.name in declared:
                idents.append(Ident(declared[v.name]))
            else:
                self.vars.append(v)
                declared[v.name] = v
                idents.append(Ident(v))
        return idents[0] if numel(vars) == 0 else idents

    def execute(self, stmt: Stmt | Expr, *args: Stmt | Expr):
        """Append statements (bare expressions become StmtExpr) to the body.

        Raises:
            TypeError: if an argument is neither a Stmt nor an Expr.
        """
        def push(n):
            if isinstance(n, Stmt):
                self.stmts.append(n)
            elif isinstance(n, Expr):
                self.stmts.append(StmtExpr(n))
            else:
                raise TypeError("unrecognized type for function")

        push(stmt)
        for arg in args:
            push(arg)

    def variables(self, *idents: str) -> Expr | List[Expr]:
        """Look up locals/params by name and return Ident(s) for them.

        A parameter shadows a local of the same name. Returns a single Ident
        for one name, otherwise a list. Raises KeyError for unknown names.
        """
        table = {v.name: v for v in [*self._flatvars(), *self.params]}
        if numel(idents) == 1:
            return Ident(table[idents[0]])
        return [Ident(table[name]) for name in idents]


class Var(Decl):
    """A variable declaration: type, name and storage class."""

    def __init__(self, type: Type, name: str, storage: Mem = Mem.Auto):
        self.name = name
        self.type = type
        self.storage = storage

    def emit(self):
        emit(f"{self.name}")

    def emittype(self):
        if self.storage != Mem.Auto:
            emit(self.storage.value)
            emit(" ")
        self.type.emit()

    def emitspec(self):
        self.type.emitspec(self.name)

    @classmethod
    def copy(cls, other: Var):
        return cls(other.type, other.name, other.storage)


class Param(Var):
    """A function parameter (always automatic storage)."""

    def __init__(self, type: Type, name: str):
        super().__init__(type, name, Mem.Auto)

    @classmethod
    def copy(cls, other: Param):
        return cls(other.type, other.name)


def Params(*ps: Tuple[Type, str]) -> List[Param]:
    """Build Params from ``(type, name)`` pairs."""
    return [Param(p[0], p[1]) for p in ps]


def Vars(*ps: Tuple[Type, str]) -> List[Var]:
    """Build Vars from ``(type, name)`` pairs."""
    return [Var(p[0], p[1]) for p in ps]
# ------------------------------------------------------------------------
# AST modification/production functions

# ------------------------------------------
# basic (non-recursive) commands


def Swap(x: Var, y: Var, tmp: Var) -> Stmt:
    """Return the statement ``tmp = x, x = y, y = tmp``, which exchanges x and y."""
    return StmtExpr(ExprList(Set(tmp, x), Set(x, y), Set(y, tmp)))


def IsLoop(s: Stmt) -> bool:
    """Return True if *s* is exactly a For statement."""
    return type(s) == For


def EvenTo(x: Expr, n: int) -> Expr:
    """Return ``x & ~(n-1)``, i.e. x rounded down to a multiple of *n*.

    The mask trick is only correct for a positive power of two.

    Raises:
        ValueError: if *n* is not a positive power of two. Such values used
            to generate a wrong bound silently.
    """
    if n <= 0 or n & (n - 1):
        raise ValueError(f"EvenTo requires a positive power of two, got {n}")
    return And(x, Negate(I(n - 1)))

# ------------------------------------------
# Expand: takes statement, indexed, and expands it x times

# def Expand(s: Stmt, times: int) -> Block:
#     if not isinstance(s, StmtExpr):
#         raise TypeError(f"{type(x)} not supported by Expand operation")
#     return Block(StmtExpr(Step(s.x, i)) for i in range(times))

# ------------------------------------------
# repeat: turns a loop step expression into one that advances `times` steps


@singledispatch
def Repeat(x: Expr, times: int) -> Expr:
    raise TypeError(f"{type(x)} not supported by Repeat operation")


@Repeat.register
def _(x: Inc, times: int):
    return AddSet(x.x, I(times))


@Repeat.register
def _(x: Dec, times: int):
    # There is no DecSet node; the compound decrement is SubSet ("-=").
    return SubSet(x.x, I(times))


@Repeat.register
def _(x: Comma, times: int):
    return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times))

# ------------------------------------------
# step: indexes an expression by i


def _RebuildUnary(node, child):
    """Clone unary *node* around *child*, keeping Inc/Dec's prefix/postfix form."""
    if isinstance(node, (Inc, Dec)):
        return type(node)(child, pre=node.pre)
    return type(node)(child)


@singledispatch
def Step(x: Expr, i: int) -> Expr:
    raise TypeError(f"{type(x)} not supported by Step operation")


@Step.register
def _(x: Comma, i: int):
    return Comma(Step(x.expr[0], i), Step(x.expr[1], i))


@Step.register
def _(x: Assign, i: int):
    return type(x)(Step(x.lhs, i), Step(x.rhs, i))


@Step.register
def _(x: BinaryOp, i: int):
    return type(x)(Step(x.l, i), Step(x.r, i))


@Step.register
def _(x: UnaryOp, i: int):
    return _RebuildUnary(x, Step(x.x, i))


@Step.register
def _(x: Ident, i: int):
    return x


@Step.register
def _(x: Literal, i: int):
    # Constants do not depend on the step.
    return x


@Step.register
def _(x: Paren, i: int):
    return Paren(Step(x.x, i))


@Step.register
def _(x: Deref, i: int):
    # *p stepped by i becomes p[i]
    return Index(x.x, I(i))


@Step.register
def _(ix: Index, i: int):
    return Index(ix.x, Add(ix.i, I(i)))

# ------------------------------------------
# Walk over the statements of an AST, calling func on each top-level expression


@singledispatch
def Visit(s: Stmt, func):
    raise TypeError(f"{type(s)} not supported by Visit operation")


@Visit.register
def _(s: Empty, func):
    return


@Visit.register
def _(blk: Block, func):
    for stmt in blk.stmts:
        Visit(stmt, func)


@Visit.register
def _(jmp: If, func):
    func(jmp.cond)
    Visit(jmp.then, func)
    if jmp.orelse is not None:
        Visit(jmp.orelse, func)


@Visit.register
def _(loop: For, func):
    # Header clauses are optional (Unroll clears init), so skip empty ones.
    for clause in (loop.init, loop.cond, loop.step):
        if clause is not None:
            func(clause)
    Visit(loop.body, func)


@Visit.register
def _(ret: Return, func):
    func(ret.val)


@Visit.register
def _(x: StmtExpr, func):
    func(x.x)

# ------------------------------------------
# recreates a piece of an AST, allowing an arbitrary transformation of expression nodes


def _MapOptional(func, x):
    """Apply *func* to *x* unless x is None."""
    return None if x is None else func(x)


@singledispatch
def Make(s: Stmt, func):
    raise TypeError(f"{type(s)} not supported by Make operation")


@Make.register
def _(s: Empty, func):
    return Empty()


@Make.register
def _(blk: Block, func):
    return Block(*(Make(stmt, func) for stmt in blk.stmts))


@Make.register
def _(loop: For, func):
    return For(
        _MapOptional(func, loop.init),
        _MapOptional(func, loop.cond),
        _MapOptional(func, loop.step),
        Make(loop.body, func),
    )


@Make.register
def _(jmp: If, func):
    if jmp.orelse is not None:
        return If(func(jmp.cond), Make(jmp.then, func), Make(jmp.orelse, func))
    return If(func(jmp.cond), Make(jmp.then, func))


@Make.register
def _(ret: Return, func):
    return Return(func(ret.val))


@Make.register
def _(x: StmtExpr, func):
    return StmtExpr(func(x.x))

# ------------------------------------------
# GetType function


@singledispatch
def GetType(x: Expr) -> Type:
    """Return the C type of expression *x* (None for None or an unresolvable index)."""
    if x is None:
        return
    raise TypeError(f"{type(x)} not supported by GetType operation")


@GetType.register
def _(x: Empty):
    return Void


@GetType.register
def _(x: S):
    return Ptr(Byte)


@GetType.register
def _(x: I):
    return Int


@GetType.register
def _(x: F):
    return Float64


@GetType.register
def _(x: Ident):
    return x.var.type


@GetType.register
def _(x: UnaryOp):
    return GetType(x.x)


@GetType.register
def _(x: Deref):
    base = GetType(x.x)
    # In C, dereferencing an array yields its first element.
    if isinstance(base, Array):
        return base.base
    return base.to


@GetType.register
def _(x: Ref):
    return Ptr(GetType(x.x))


@GetType.register
def _(x: Index):
    base = GetType(x.x)
    if isinstance(base, Ptr):
        return base.to
    elif isinstance(base, Array):
        return base.base
    # TODO: raise TypeError(f"attempting to index type {base} of node {x.x}")
    return None


# TODO: type checking for both operands
@GetType.register
def _(x: BinaryOp):
    return GetType(x.l)


@GetType.register
def _(x: Var):
    return x.type


@GetType.register
def _(x: Assign):
    return GetType(x.lhs)


@GetType.register
def _(x: Comma):
    return GetType(x.expr[1])


@GetType.register
def _(x: Paren):
    return GetType(x.x)


@GetType.register
def _(x: Call):
    return x.func.ret

# ------------------------------------------
# Transform function
# Recurses down the AST and rebuilds every Expr node, post-order, passing
# each rebuilt node through func. Returns a brand new tree structure.


@singledispatch
def Transform(x: Expr, func: Callable[[Expr], Expr]):
    raise TypeError(f"{type(x)} not supported by Transform operation")


@Transform.register
def _(x: Empty, func):
    return func(Empty())


@Transform.register
def _(x: Literal, func):
    return func(type(x)(x))


@Transform.register
def _(x: Ident, func):
    return func(Ident(x.var))


@Transform.register
def _(x: UnaryOp, func):
    # Preserve the concrete operator (and Inc/Dec's pre flag); a bare UnaryOp
    # would emit nothing.
    return func(_RebuildUnary(x, Transform(x.x, func)))


@Transform.register
def _(x: Assign, func):
    return func(type(x)(Transform(x.lhs, func), Transform(x.rhs, func)))


@Transform.register
def _(x: Deref, func):
    return func(Deref(Transform(x.x, func)))


@Transform.register
def _(x: Index, func):
    return func(Index(Transform(x.x, func), Transform(x.i, func)))


@Transform.register
def _(x: Paren, func):
    return func(Paren(Transform(x.x, func)))


@Transform.register
def _(x: Call, func):
    # Call arguments may be Var/Param declarations rather than Expr nodes;
    # those are passed through untouched.
    args = [Transform(a, func) if isinstance(a, Expr) else a for a in x.args]
    return func(Call(x.func, args))
@Transform.register
def _(x: Comma, func):
    return func(Comma(Transform(x.expr[0], func), Transform(x.expr[1], func)))


@Transform.register
def _(x: BinaryOp, func):
    return func(type(x)(Transform(x.l, func), Transform(x.r, func)))

# ------------------------------------------
# Filter function
# Recurses down the AST and appends every node that satisfies cond to results
# (pre-order: a node is recorded before its children).


@singledispatch
def Filter(x: Expr, cond, results: List[Expr]):
    raise TypeError(f"{type(x)} not supported by Filter operation")


@Filter.register
def _(x: Empty, cond, results: List[Expr]):
    if cond(x):
        results.append(x)


@Filter.register(Ident)
@Filter.register(Literal)
def _(x, cond, results: List[Expr]):
    if cond(x):
        results.append(x)


@Filter.register
def _(x: UnaryOp, cond, results: List[Expr]):
    if cond(x):
        results.append(x)
    Filter(x.x, cond, results)


@Filter.register
def _(s: Assign, cond, results: List[Expr]):
    if cond(s):
        results.append(s)
    Filter(s.lhs, cond, results)
    Filter(s.rhs, cond, results)


@Filter.register
def _(v: Deref, cond, results: List[Expr]):
    if cond(v):
        results.append(v)
    Filter(v.x, cond, results)


@Filter.register
def _(i: Index, cond, results: List[Expr]):
    if cond(i):
        results.append(i)
    Filter(i.x, cond, results)
    Filter(i.i, cond, results)


@Filter.register
def _(comma: Comma, cond, results: List[Expr]):
    if cond(comma):
        results.append(comma)
    Filter(comma.expr[0], cond, results)
    Filter(comma.expr[1], cond, results)


@Filter.register
def _(op: BinaryOp, cond, results: List[Expr]):
    if cond(op):
        results.append(op)
    Filter(op.l, cond, results)
    Filter(op.r, cond, results)


@Filter.register
def _(p: Paren, cond, results: List[Expr]):
    if cond(p):
        results.append(p)
    Filter(p.x, cond, results)


@Filter.register
def _(c: Call, cond, results: List[Expr]):
    if cond(c):
        results.append(c)
    for arg in c.args:
        # Arguments may be Var/Param declarations rather than Expr nodes.
        if isinstance(arg, Expr):
            Filter(arg, cond, results)

# ------------------------------------------
# Eval function
# Recurses down the AST and evaluates Expr nodes symbolically with sympy.
# Identifiers become sympy symbols; raises TypeError for unsupported nodes.


@singledispatch
def Eval(x: Expr) -> object:
    raise TypeError(f"{type(x)} not supported by Eval operation")


@Eval.register
def _(x: Empty):
    pass


@Eval.register(Ident)
@Eval.register(S)
def _(x):
    return sympy.symbols(f"{x}")


@Eval.register(float)
@Eval.register(int)
@Eval.register(I)
@Eval.register(F)
def _(x):
    return x


@Eval.register
def _(x: Inc):
    # Value of the operand plus one.
    return sympy.simplify(Eval(x.x) + 1)


@Eval.register
def _(x: Dec):
    # Value of the operand minus one.
    return sympy.simplify(Eval(x.x) - 1)


# TODO: This won't work in general (if we have things like sizeof(x) + 1 - sizeof(x))
@Eval.register
def _(op: Add):
    return sympy.simplify(Eval(op.l) + Eval(op.r))


@Eval.register
def _(op: Sub):
    return sympy.simplify(Eval(op.l) - Eval(op.r))


@Eval.register
def _(op: Mul):
    return sympy.simplify(Eval(op.l) * Eval(op.r))


@Eval.register
def _(op: Div):
    return sympy.simplify(Eval(op.l) / Eval(op.r))


@Eval.register
def _(s: Assign):
    return Eval(s.rhs)


@Eval.register
def _(comma: Comma):
    return Eval(comma.expr[1])


@Eval.register
def _(p: Paren):
    return Eval(p.x)

# ------------------------------------------
# Leaf traversal


def VarsUsed(stmt: Stmt) -> set:
    """Return the set of Var objects referenced by any Ident inside *stmt*."""
    idents = []
    Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), idents))
    return set([v.var for v in idents])


def AddrAccessed(stmt: Stmt):
    """Split the identifiers used in *stmt* into scalar and indexed accesses.

    Returns:
        (scalars, vectors): ``scalars`` is the set of Idents used other than
        as the base of an Index. ``vectors`` maps each indexed base Ident to
        the index expression of its last Index occurrence.

    Raises:
        ValueError: if an Index base involves more than one identifier.
    """
    scalars = []  # variables that are accessed as scalars (single sites)
    vectors = {}  # variables that are accessed as vectors (indexed/dereferenced)
    nodes = []
    Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), scalars))
    Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Index), nodes))

    for node in nodes:
        bases = []
        Filter(node.x, lambda x: isinstance(x, Ident), bases)
        if numel(bases) != 1:
            raise ValueError("multiple variables used in index expression not supported")
        vectors[bases[0]] = node.i
        # Each Index contributes exactly one occurrence of its base to
        # `scalars`; remove that one occurrence.
        if bases[0] in scalars:
            scalars.remove(bases[0])

    return set(scalars), vectors

# ------------------------------------------
# Large scale functions


def Unroll(name: str, loop: For, times: int, ret: Ident = None,
           accumulator=None) -> Tuple[Func, List[Stmt]]:
    """Split *loop* into an unrolled kernel function plus a scalar remainder.

    The kernel ``{name}_kernel{times}`` takes a pointer to the loop iterator
    as its first parameter. It runs the body *times* copies per iteration up
    to ``*ip`` rounded down to a multiple of *times*, then writes the final
    iterator value back through ``ip``. Every local (Var) used in the body
    becomes an array of *times* elements, one per unrolled copy. The
    iterator itself stays a scalar.

    If *ret* is given, the kernel returns it. When ret's local was expanded to
    an array, *accumulator(ires, icur, node)* must build the statement that
    folds element ``icur`` into element ``ires`` of array ``node``.

    Side effect: ``loop.init`` is set to None, so the returned loop resumes
    from wherever the kernel stopped.

    Returns:
        (kernel, calls), where calls are the statements to splice into the
        original function: ``i = len``, the kernel call (assigned to *ret*
        if given), and the remainder loop.

    Raises:
        TypeError: if the loop condition is not ``<``/``<=``, or its bound is
            not a function parameter.
        ValueError: if a value is returned from an expanded local without an
            accumulator.
    """
    # TODO: More sophisticated computation for length of loop
    if not isinstance(loop.cond, (LE, LT)):
        raise TypeError(f"{type(loop.cond)} not supported in loop unrolling")

    # pull off needed features of the loop
    i = loop.init.lhs.var
    n = loop.cond.r.var
    # TODO: More sophisticated type checking
    if type(n) != Param:
        raise TypeError(f"{type(n)} not implemented yet")

    # VarsUsed returns a set of Var objects hashed by identity; sort by name
    # so the generated parameter and declaration order is reproducible.
    used = sorted(VarsUsed(loop.body), key=lambda v: v.name)

    def asvector(v):
        if v != i:
            return Var(Array(v.type, times), v.name, v.storage)
        return Var.copy(v)

    param = [Param(Ptr(i.type), f"{i.name}p")]
    param += [Param.copy(v) for v in used if type(v) == Param and v != n]
    stack = {v.name: asvector(v) for v in used if type(v) == Var}

    rettype = Void if ret is None else ret.var.type
    kernel = Func(f"{name}_kernel{times}", rettype, param, list(stack.values()))

    bound = kernel.declare(n)
    itor, itorp = kernel.variables(i.name, f"{i.name}p")

    def mkarray(x: Expr):
        # Point identifiers of stack locals at their unrolled declarations.
        if isinstance(x, Ident) and type(x.var) == Var:
            x.var = stack[x.name]
        return x

    def step(x: Expr, k: int):
        # Copy k: the iterator becomes i + k; each local becomes local[k].
        if x == itor:
            return Add(x, I(k))
        if isinstance(x, Ident) and type(x.var) == Var:
            return Index(x, I(k))
        if not isinstance(x, Expr):
            raise ValueError(f"panic, hit type {type(x)}")
        return x

    expandedloop = Make(loop.body, lambda node: Transform(node, mkarray))

    def shifted(k):
        return Make(expandedloop, lambda node: Transform(node, lambda x: step(x, k)))

    kernel.execute(
        Set(bound, EvenTo(Deref(itorp), times)),
        For(Set(itor, I(0)), LT(itor, bound), Repeat(loop.step, times),
            body=Block(*(shifted(k) for k in range(times)))),
        Set(Deref(itorp), itor),
    )

    if ret is not None:
        k = kernel.variables(ret.name)
        if k.var.type != ret.var.type:
            # The returned local was expanded into an array: fold elements
            # 1..times-1 into element 0, then return element 0.
            if accumulator is None:
                raise ValueError("If loop returns a value, an accumulator must be given")
            kernel.execute(For(Set(itor, I(1)), LT(itor, I(times)), Inc(itor),
                               body=accumulator(I(0), itor, k)))
            kernel.execute(Return(Index(k, I(0))))
        else:
            kernel.execute(Return(k))

    loop.init = None
    call = Call(kernel, [Ref(itor)] + param[1:])
    if ret is None:
        return kernel, [Set(itor, bound), call, loop]
    return kernel, [Set(itor, bound), Set(ret, call), loop]


class SymTab(object):
    """Names rewritten during vectorisation.

    ``stack`` maps a name to the Ident of its SIMD register variable.
    ``addrs`` maps a parameter name to a Var that retypes it as a pointer to
    SIMD registers.
    """

    def __init__(self):
        self.stack = {}
        self.addrs = {}

    def have(self, name):
        return name in self.stack or name in self.addrs


# Replaces all vectorizable loops inside Func with vectorized variants
# Returns the new vectorized function (tagged by the SIMD chosen)
def Vectorize(func: Func, isa: SIMD) -> Func:
    """Return ``{func.name}_{isa}``, with groups of 4 unrolled statements fused into AVX2 ops.

    Only AVX2 with float64 (4 lanes) is implemented. Non-loop statements are
    copied as-is.

    Raises:
        ValueError: for unsupported ISAs, or loop bodies that cannot be
            grouped and vectorised uniformly.
        TypeError: if an indexed local is not an array.
    """
    if isa != SIMD.AVX2:
        raise ValueError(f"ISA '{isa}' not currently implemented")

    vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params)
    for stmt in func.stmts:
        if not IsLoop(stmt):
            vfunc.execute(stmt)
            continue

        loop = For(stmt.init, stmt.cond, stmt.step)
        iterator = set([stmt.init.lhs])
        body = stmt.body
        if type(body) != Block:
            # TODO: Think through this carefully...
            # This is coded for the accumulation step at the end of a kernel.
            # Build a new condition instead of mutating stmt.cond, which is
            # shared with the source kernel's AST.
            loop.cond = type(loop.cond)(loop.cond.l, I(Eval(Div(loop.cond.r, 4))))
            loop.body = body
            vfunc.execute(loop)
            continue

        instr = body.stmts
        # TODO: Remove hardcoded 4 -> should be function of types!
        # As of now, we have hardcoded AVX2 <-> float64 relationship
        if numel(instr) % 4 != 0:
            raise ValueError("loop can not be vectorized, instructions can not be globbed equally")

        syms = SymTab()

        # IMPORTANT: Transform is post-order, so leaves (identifiers) are
        # translated first and then their parents, back up to the root.
        def translate(x):
            if isinstance(x, Ident):
                if x.name in syms.stack:
                    return syms.stack[x.name]
                elif x.name in syms.addrs:
                    x.var = syms.addrs[x.name]
            if isinstance(x, Index):
                if type(x.x) == Ident:
                    if x.x.name in syms.addrs:
                        return Add(x.x, x.i)
                    # NOTE: This is hacky. Leaves were already renamed to
                    # "<name>256", so strip the suffix to find the original.
                    elif f"{x.x.name[:-3]}" in syms.stack:
                        return Index(x.x, I(Eval(Div(x.i, 4))))
            if isinstance(x, BinaryOp):
                l, r = GetType(x.l), GetType(x.r)
                if IsVectorType(l) or IsVectorType(r):
                    if type(l) == Ptr:
                        x.l = Deref(x.l)
                    if type(r) == Ptr:
                        x.r = Deref(x.r)
            return x

        # TODO: Allow for non-sequential accesses?
        for i in range(0, numel(instr), 4):
            scalars = []
            vectors = []
            for j in range(4):
                s, v = AddrAccessed(instr[i + j])
                s -= iterator  # TODO: This is hacky
                scalars.append(frozenset(s))
                vectors.append(v)

            # Test if code is uniform to allow for vectorization
            if numel(set(scalars)) != 1:
                raise ValueError("non uniform scalar accesses in consecutive line. can not vectorize")
            for j in range(1, 4):
                for v, idx in vectors[j].items():
                    delta = Eval(Sub(idx, vectors[j - 1][v]))
                    if delta != 1:
                        raise ValueError(
                            "non uniform vector accesses in consecutive line. can not vectorize"
                            f" (index delta: {delta})"
                        )

            # If we made it to here, we have passed all checks. vectorize!
            vecs = vectors[0]
            if i == 0:
                # Broadcast each scalar into a 256-bit register, in name order
                # so the generated declarations are reproducible.
                for s in sorted(scalars[0], key=lambda sym: sym.name):
                    intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128"))
                    syms.stack[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
                    vfunc.execute(Set(intermediate, Ref(s)))
                    vfunc.execute(Set(syms.stack[s.name], intermediate))
                for v in vecs.keys():
                    # All params are treated AS addresses to load into vectorized registers
                    if type(v.var) == Param:
                        syms.addrs[v.name] = Var(Ptr(Float64x4), f"{v.name}256")
                        continue
                    # All stack declared variables MUST be moved to vectorized registers
                    if not IsArrayType(v.var.type):
                        raise TypeError(f"must be an array type, instead got {v.var.type}")
                    nreg = v.var.type.len // 4
                    if nreg > 1:
                        syms.stack[v.name] = vfunc.declare(Var(Array(Float64x4, nreg), f"{v.name}256"))
                    else:
                        syms.stack[v.name] = vfunc.declare(Var(Float64x4, f"{v.name}256"))

            loop.execute(Make(instr[i], lambda node: Transform(node, translate)))
        vfunc.execute(loop)

    return vfunc


def Strided(func: Func) -> Func:
    pass

# ------------------------------------------------------------------------
# Point of testing


def copy():
    """Emit the unrolled and AVX2 kernels and the driver for blas·copy.

    NOTE(review): the loop body swaps x[i] and y[i] (via Swap) rather than
    copying x into y; confirm this is intended.
    """
    F = Func("blas·copy", Void,
             Params((Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y")),
             Vars((Int, "i"), (Float64, "tmp")))

    # could also declare like
    # F.declare( ... )

    len, x, y, i, tmp = F.variables("len", "x", "y", "i", "tmp")

    loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
    loop.execute(Swap(Index(x, i), Index(y, i), tmp))

    kernel, calls = Unroll("copy", loop, 8)
    kernel.emit()

    avx256kernel = Vectorize(kernel, SIMD.AVX2)
    avx256kernel.emit()

    F.execute(*calls)
    F.emit()


def axpby():
    """Emit the kernels and the driver for blas·axpby.

    The body computes y[i] = b * (y[i] + a * x[i]).
    NOTE(review): conventional BLAS axpby is a*x + b*y; confirm the formula.
    """
    F = Func("blas·axpby", Void,
             Params((Int, "len"), (Float64, "a"), (Ptr(Float64), "x"),
                    (Float64, "b"), (Ptr(Float64), "y")),
             Vars((Int, "i")))

    # could also declare like
    # F.declare( ... )

    # TODO: Increase ergonomics here...
    len, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b")

    loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
    loop.execute(Set(Index(y, i), Mul(b, Add(Index(y, i), Mul(a, Index(x, i))))))

    kernel, calls = Unroll("axpby", loop, 8)
    kernel.emit()

    avx256kernel = Vectorize(kernel, SIMD.AVX2)
    avx256kernel.emit()

    F.execute(*calls)
    F.emit()


def argmax():
    """Emit the unrolled kernel and the driver for blas·argmax."""
    F = Func("blas·argmax", Int,
             Params((Int, "len"), (Ptr(Float64), "x")),
             Vars((Int, "i"), (Int, "ix"), (Float64, "max")))

    len, x, i, ix, max = F.variables("len", "x", "i", "ix", "max")

    loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
    loop.execute(
        If(GT(Index(x, i), max),
           Block(
               Set(ix, i),
               Set(max, Index(x, ix)),
           ))
    )

    kernel, calls = Unroll(
        "argmax", loop, 8, ix,
        lambda ires, icur, node: If(
            GT(Index(x, Index(node, icur)), Index(x, Index(node, ires))),
            Block(Set(Index(node, ires), Index(node, icur))),
        ),
    )
    kernel.emit()

    # avx256kernel = Vectorize(kernel, SIMD.AVX2)
    # avx256kernel.emit()

    # `calls` already ends with the remainder loop; adding it again would
    # emit a duplicate loop after the return statement.
    F.execute(*calls, Return(ix))
    F.emit()


def dot():
    """Emit the unrolled and AVX2 kernels and the driver for blas·dot."""
    F = Func("blas·dot", Float64,
             Params((Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y")),
             Vars((Int, "i"), (Float64, "sum")))

    len, x, i, y, sum = F.variables("len", "x", "i", "y", "sum")

    loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
    loop.execute(AddSet(sum, Mul(Index(x, i), Index(y, i))))

    kernel, calls = Unroll(
        "dot", loop, 16, sum,
        lambda ires, icur, node: StmtExpr(AddSet(Index(node, ires), Index(node, icur))),
    )
    kernel.emit()

    avx256kernel = Vectorize(kernel, SIMD.AVX2)
    avx256kernel.emit()

    F.execute(*calls, Return(sum))
    F.emit()


def gemv():
    """Emit blas·gemv: y[r] = a * (row r of m . x) + b * y[r].

    Row r of m starts at m + r*incm.
    """
    F = Func("blas·gemv", Void,
             Params((Int, "nrow"), (Int, "ncol"), (Float64, "a"), (Ptr(Float64), "m"),
                    (Int, "incm"), (Ptr(Float64), "x"), (Float64, "b"), (Ptr(Float64), "y")),
             Vars((Int, "r"), (Int, "c"), (Ptr(Float64), "row"), (Float64, "res")))

    r, c, row, res = F.variables("r", "c", "row", "res")
    nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")

    loop = For(Set(r, I(0)), LT(r, nrow), Inc(r),
               Block(
                   # Row r begins r strides of incm into m.
                   Set(row, Add(m, Mul(r, incm))),
                   Set(res, I(0)),
                   For(Set(c, I(0)), LT(c, ncol), Inc(c),
                       AddSet(res, Mul(Index(row, c), Index(x, c)))),
                   Set(Index(y, r), Add(Mul(a, res), Mul(b, Index(y, r)))),
               ))

    F.execute(loop)
    F.emit()


if __name__ == "__main__":
    emitheader()
    gemv()
    # dot()
    # argmax()
    # copy()
    # axpby()
    print(buffer)