#! /bin/python3 from __future__ import annotations import os import sys import sympy from enum import Enum from typing import List, Tuple, Dict from functools import singledispatch numel = len # ------------------------------------------------------------------------ # String buffer level = 0 buffer = "" def emit(s: str): global buffer buffer += s def emitln(n=1): global buffer buffer += "\n"*n + (" " * level) def enter_scope(): global level emit("{") level += 1 emitln() def exits_scope(): global level level -= 1 emitln() emit("}") def emitheader(): emit("#include \n") emit("#include \n") emit("#include \n") emitln() # ------------------------------------------------------------------------ # Simple C AST # TODO: Type checking # Abstract class everything will derive from # All AST nodes will have an "emit" function that outputs formatted C code class Emitter(object): def emit(self): pass # ------------------------------------------ # Representation of a C type class Type(Emitter): def emit(self): pass def emitspec(self, var): pass class Base(Type): def __init__(self, name: str): self.name = name def __hash__(self): return hash(self.name) def __eq__(self, other): return type(self) == type(other) and self.name == other.name def __str__(self): return self.name def emit(self): emit(self.name) def emitspec(self, ident): emit(f"{ident}") # TODO: Operation lookup tables... 
class Ptr(Type):
    """Pointer to another `Type`."""

    def __init__(self, to: Type):
        self.to = to

    def __str__(self):
        return f"*{self.to}"

    def __eq__(self, other):
        return type(self) == type(other) and self.to == other.to

    def __hash__(self):
        return hash((hash(self.to), "*"))

    def emit(self):
        # The type prefix is that of the pointee; the "*" belongs to the
        # declarator and is written by emitspec.
        self.to.emit()

    def emitspec(self, ident):
        emit("*")
        self.to.emitspec(ident)


class Array(Type):
    """Fixed-length array of `base` with `len` elements."""

    def __init__(self, base: Type, len: int):
        self.base = base
        self.len = len

    def __str__(self):
        return f"[{self.len}]{self.base}"

    def __eq__(self, other):
        # The length takes part in equality so that __eq__ agrees with
        # __hash__ (equal objects must hash equal).
        return (type(self) == type(other)
                and self.base == other.base
                and self.len == other.len)

    def __hash__(self):
        return hash((hash(self.base), self.len))

    def emit(self):
        self.base.emit()

    def emitspec(self, ident):
        self.base.emitspec(ident)
        emit(f"[{self.len}]")


# Machine primitive types
Void = Base("void")
Error = Base("error")
Byte = Base("byte")
String = Ptr(Byte)

Int = Base("int")
Int8 = Base("int8")
Int16 = Base("int16")
Int32 = Base("int32")
Int64 = Base("int64")

# x86 vector types, spelled as declared by <immintrin.h>.
Int32x4 = Base("__m128i")
Int32x8 = Base("__m256i")
Int64x2 = Base("__m128i")
Int64x4 = Base("__m256i")

Float32 = Base("float")
Float64 = Base("double")

Float32x4 = Base("__m128")
Float32x8 = Base("__m256")
Float64x2 = Base("__m128d")
Float64x4 = Base("__m256d")

_VECTOR_TYPES = (
    Float32x4, Float32x8, Float64x2, Float64x4,
    Int32x4, Int32x8, Int64x2, Int64x4,
)


def IsVectorType(kind):
    """Return True if `kind` is one of the vector primitives above, or a
    pointer (at any depth) to one. Comparison is by identity."""
    if any(kind is vec for vec in _VECTOR_TYPES):
        return True
    if type(kind) == Ptr:
        return IsVectorType(kind.to)
    return False


# TODO: Make this ARCH dependent
BitDepth = {
    Void: 8,
    Int: 32,
    Int32: 32,
    Int64: 64,
    Float32: 32,
    Float64: 64,
}


class SIMD(Enum):
    SSE = "sse"
    SSE2 = "sse2"
    AVX = "avx"
    AVX2 = "avx2"
    FMA3 = "fma3"
    AVX5 = "avx512"


# Register width in bits for each instruction set
RegisterSize = {
    SIMD.SSE: 128,
    SIMD.SSE2: 128,
    SIMD.AVX: 256,
    SIMD.AVX2: 256,
    SIMD.FMA3: 256,
    SIMD.AVX5: 512,
}

# Scalar element types listed for each instruction set
SIMDSupport = {
    SIMD.SSE: {Float32},
    SIMD.SSE2: {Float32, Float64, Int, Int8, Int16, Int32, Int64},
    SIMD.AVX: {Float32, Float64, Int, Int8, Int16, Int32, Int64},
    SIMD.AVX2: {Float32, Float64, Int, Int8, Int16, Int32, Int64},
    SIMD.FMA3: {Float32, Float64, Int, Int8, Int16, Int32, Int64},
    SIMD.AVX5: {Float32, Float64, Int, Int8, Int16, Int32, Int64},
}


# TODO: Think of a better way to handle this
def emit·load128(l, r):
    """Emit `[l = ]_mm_loadu_sd(r)`; the assignment is skipped when `l` is None."""
    # NOTE(review): `_mm_loadu_sd` does not look like a documented intrinsic
    # (`_mm_load_sd` is) - confirm against the Intel Intrinsics Guide.
    if l is not None:
        l.emit()
        emit(" = ")
    emit("_mm_loadu_sd(")
    r.emit()
    emit(")")


def emit·broadcast256(l, r):
    """Emit `l = _mm256_broadcastsd_pd(r)`."""
    l.emit()
    emit(" = _mm256_broadcastsd_pd(")
    r.emit()
    emit(")")


def emit·load256(l, r):
    """Emit `[l = ]_mm256_loadu_pd(r)`; the assignment is skipped when `l` is None."""
    if l is not None:
        l.emit()
        emit(" = ")
    emit("_mm256_loadu_pd(")
    r.emit()
    emit(")")


def emit·store256(l, r):
    """Emit `_mm256_storeu_pd(l, r)`."""
    emit("_mm256_storeu_pd(")
    l.emit()
    emit(", ")
    r.emit()
    emit(")")


def emit·copy256(l, r):
    """Emit `_mm256_storeu_pd(l, _mm256_loadu_pd(r))`."""
    emit("_mm256_storeu_pd(")
    l.emit()
    emit(", ")
    emit("_mm256_loadu_pd(")
    r.emit()
    emit("))")


# TODO: Typedefs...

# ------------------------------------------
# C expressions

class Expr(Emitter):
    def emit(self):
        pass


# Literals
class Literal():
    """Mixin for literal values; emits the value's own string form."""

    def emit(self):
        emit(f"{self}")


class I(Literal, int):
    def __new__(cls, i: int):
        return super(I, cls).__new__(cls, i)


class F(Literal, float):
    def __new__(cls, f: float):
        return super(F, cls).__new__(cls, f)


class S(Literal, str):
    def __new__(cls, s: str):
        return super(S, cls).__new__(cls, s)


# Ident of symbol
class Ident(Expr):
    """Reference to a declared variable; compares and hashes by name only."""

    def __init__(self, var):
        self.name = var.name
        self.var = var

    def __str__(self):
        return str(self.name)

    def __hash__(self):
        return hash(self.name)

    def __eq__(self, other):
        return type(self) == type(other) and self.name == other.name

    def emit(self):
        emit(f"{self.name}")


# Unary operators
class UnaryOp(Expr):
    def __init__(self, x: Expr):
        self.x = x

    def emit(self):
        pass


class Deref(UnaryOp):
    # Operand types whose dereference is emitted as an intrinsic load
    method = {
        Ptr(Float64x4): lambda x: emit·load256(None, x),
    }

    def emit(self):
        kind = GetType(self.x)
        if kind in self.method:
            self.method[kind](self.x)
        else:
            emit("*")
            self.x.emit()


class Negate(UnaryOp):
    """Bitwise complement (`~x`)."""

    def emit(self):
        emit("~")
        self.x.emit()


class Ref(UnaryOp):
    """Address-of (`&x`)."""

    def emit(self):
        emit("&")
        self.x.emit()


class _PrePost(UnaryOp):
    """Shared implementation of `++`/`--`; `pre` selects the prefix form."""

    token = ""

    def __init__(self, x: Expr, pre=False):
        self.x = x
        self.pre = pre

    def emit(self):
        if self.pre:
            emit(self.token)
            self.x.emit()
        else:
            self.x.emit()
            emit(self.token)


class Inc(_PrePost):
    token = "++"


class Dec(_PrePost):
    token = "--"


def _rebuild_unary(node: UnaryOp, operand: Expr) -> UnaryOp:
    """Return a node of the same class as `node` wrapping `operand`,
    carrying over the prefix/postfix flag of `Inc`/`Dec`."""
    if isinstance(node, _PrePost):
        return type(node)(operand, node.pre)
    return type(node)(operand)


# Binary operators
class BinaryOp(Expr):
    def __init__(self, left: Expr, right: Expr):
        self.l = left
        self.r = right

    def emit(self):
        pass


def _emit_operand(x):
    """Emit an operand of an infix operator, parenthesising compound
    operands so that the tree's grouping survives C operator precedence."""
    if isinstance(x, (BinaryOp, Assign, Comma)):
        emit("(")
        x.emit()
        emit(")")
    else:
        x.emit()


# TODO: check types if they are vectorized and emit correct intrinsic
class _Infix(BinaryOp):
    """Binary operator written as `l <op> r`."""

    op = ""

    def emit(self):
        _emit_operand(self.l)
        emit(f" {self.op} ")
        _emit_operand(self.r)


class Add(_Infix):
    op = "+"


class Sub(_Infix):
    op = "-"


class Mul(_Infix):
    op = "*"


class Div(_Infix):
    op = "/"


class And(_Infix):
    op = "&"


class Xor(_Infix):
    op = "^"


class GT(_Infix):
    op = ">"


class LT(_Infix):
    op = "<"


class GE(_Infix):
    op = ">="


class LE(_Infix):
    op = "<="


class EQ(_Infix):
    op = "=="


# Loads a scalar
class SIMDLoad(BinaryOp):
    def emit(self):
        self.l.emit()
        self.r.emit()


# Loads data at address
class SIMDLoadAt(BinaryOp):
    def emit(self):
        self.l.emit()
        emit(" = _mm256_loadu_pd(")
        self.r.emit()
        emit(")")


# Assignment (stores)
class Assign(Expr):
    def __init__(self, lhs: Expr, rhs: Expr):
        self.lhs = lhs
        self.rhs = rhs


class Set(Assign):
    """Plain assignment. Certain (lhs type, rhs type) pairs are emitted as
    SIMD load/store/broadcast intrinsics instead of `lhs = rhs`."""

    method = {
        (Float64x2, Ptr(Float64)): emit·load128,
        (Float64x4, Float64x2): emit·broadcast256,
        (Float64x4, Ptr(Float64x4)): emit·load256,
        (Ptr(Float64x4), Float64x4): emit·store256,
        (Ptr(Float64x4), Ptr(Float64x4)): emit·copy256,
    }

    def emit(self):
        key = (GetType(self.lhs), GetType(self.rhs))
        if key in self.method:
            self.method[key](self.lhs, self.rhs)
        else:
            self.lhs.emit()
            emit(" = ")
            self.rhs.emit()


class _OpSet(Assign):
    """Compound assignment written as `lhs <op>= rhs`."""

    op = ""

    def emit(self):
        self.lhs.emit()
        emit(f" {self.op}= ")
        self.rhs.emit()


class AddSet(_OpSet):
    op = "+"


class SubSet(_OpSet):
    op = "-"


class MulSet(_OpSet):
    op = "*"


class DivSet(_OpSet):
    op = "/"


class Comma(Expr):
    def __init__(self, x: Expr, next: Expr):
        self.expr = (x, next)

    def emit(self):
        self.expr[0].emit()
        emit(", ")
        self.expr[1].emit()


class Index(Expr):
    def __init__(self, x: Expr, i: Expr):
        self.x = x
        self.i = i

    def emit(self):
        self.x.emit()
        emit("[")
        self.i.emit()
        emit("]")


class Paren(Expr):
    def __init__(self, x: Expr):
        self.x = x

    def emit(self):
        emit("(")
        self.x.emit()
        emit(")")


class Call(Expr):
    def __init__(self, func: Func, args: List[Param]):
        self.func = func
        self.args = args

    def emit(self):
        emit(self.func.name)
        emit("(")
        for i, arg in enumerate(self.args):
            if i > 0:
                emit(", ")
            arg.emit()
        emit(")")


def ExprList(*expr: Expr | List[Expr]):
    """Chain expressions into right-nested `Comma` nodes.

    A single expression is returned unchanged. Raises IndexError when called
    with no expressions.
    """
    if not expr:
        raise IndexError("ExprList requires at least one expression")
    chain = expr[-1]
    for x in reversed(expr[:-1]):
        chain = Comma(x, chain)
    return chain


def _as_expr(x):
    """Normalise a constructor argument to a single expression.

    None is passed through (an omitted clause), an expression or literal is
    returned as is, and any other iterable is folded with `ExprList`.
    """
    if x is None or isinstance(x, (Expr, Literal)):
        return x
    return ExprList(*x)


# Common assignments are kept globally

# C statements
class Stmt(Emitter):
    def emit(self):
        emit(";")


class Empty(Stmt):
    def __init__(self):
        pass

    def emit(self):
        super(Empty, self).emit()


class Block(Stmt):
    def __init__(self, *stmts: List[Stmt]):
        self.stmts = list(stmts)

    def emit(self):
        enter_scope()
        last = numel(self.stmts) - 1
        for i, stmt in enumerate(self.stmts):
            stmt.emit()
            # Statements terminate themselves; only bare expressions placed
            # directly in a block still need their ";".
            if not isinstance(stmt, Stmt):
                emit(";")
            if i < last:
                emitln()
        exits_scope()


class If(Stmt):
    """`if (cond) then [else orelse]`. Lists given for `then`/`orelse` are
    wrapped in a `Block`; a list given for `cond` is comma-chained."""

    def __init__(self, cond: Expr | List[Expr], then: Stmt | List[Stmt], orelse: Stmt | List[Stmt] | None = None):
        self.cond = _as_expr(cond)
        self.then = Block(*then) if isinstance(then, list) else then
        self.orelse = Block(*orelse) if isinstance(orelse, list) else orelse

    def emit(self):
        emit("if (")
        self.cond.emit()
        emit(") ")
        self.then.emit()
        if self.orelse is not None:
            emit(" else ")
            self.orelse.emit()


class For(Stmt):
    """`for (init; cond; step) body`. Any of the three clauses may be None,
    in which case it is left empty in the output."""

    def __init__(self, init: Expr | List[Expr], cond: Expr | List[Expr], step: Expr | List[Expr], body: Stmt | None = None):
        self.init = _as_expr(init)
        self.cond = _as_expr(cond)
        self.step = _as_expr(step)
        self.body = body if body is not None else Empty()

    def emit(self):
        emit("for (")
        if self.init is not None:
            self.init.emit()
        emit("; ")
        if self.cond is not None:
            self.cond.emit()
        emit("; ")
        if self.step is not None:
            self.step.emit()
        emit(") ")

        if isinstance(self.body, Block):
            self.body.emit()
        else:
            enter_scope()
            self.body.emit()
            exits_scope()

    def execute(self, stmt: Stmt | Expr):
        """Append `stmt` to the loop body.

        The first statement replaces an `Empty` body; a second one promotes
        the body to a `Block`. Bare expressions are wrapped in `StmtExpr`.
        """
        if isinstance(stmt, Expr):
            stmt = StmtExpr(stmt)

        if isinstance(self.body, Empty):
            self.body = stmt
            return
        if not isinstance(self.body, Block):
            self.body = Block(self.body)
        self.body.stmts.append(stmt)


class Return(Stmt):
    def __init__(self, val: Expr):
        self.val = val

    def emit(self):
        emitln()
        emit("return ")
        self.val.emit()
        super(Return, self).emit()


class StmtExpr(Stmt):
    def __init__(self, x: Expr):
        self.x = x

    def emit(self):
        self.x.emit()
        super(StmtExpr, self).emit()


class Mem(Enum):
    Auto = ""
    Static = "static"
    Register = "register"
    Typedef = "typedef"
    External = "extern"


class Decl(Emitter):
    def __init__(self):
        pass

    def emit(self):
        pass


class Func(Decl):
    """A C function definition: return type, parameters, locals, statements.

    Entries of `vars` are either a single `Var` or a list of `Var`s that are
    declared together on one line.
    """

    def __init__(self, name: str, ret: Expr = Void, params: List[Param] = None, vars: List[Var | List[Var]] = None, body: List[Stmt] = None):
        self.name = name
        self.ret = ret
        self.params = params if params is not None else []
        self.vars = vars if vars is not None else []
        self.stmts = body if body is not None else []

    def _flat_vars(self) -> List[Var]:
        """Return the declared locals with grouped declarations flattened."""
        flat = []
        for entry in self.vars:
            if isinstance(entry, list):
                flat.extend(entry)
            else:
                flat.append(entry)
        return flat

    def emit(self):
        self.ret.emit()
        emitln()
        emit(self.name)
        emit("(")
        for i, p in enumerate(self.params):
            p.emittype()
            emit(" ")
            p.emitspec()
            if i < numel(self.params) - 1:
                emit(", ")
        emit(")\n")

        enter_scope()
        for var in self.vars:
            # A grouped declaration shares the type prefix of its first member.
            group = var if isinstance(var, list) else [var]
            group[0].emittype()
            emit(" ")
            group[0].emitspec()
            for v in group[1:]:
                emit(", ")
                v.emitspec()
            emit(";")
            emitln()
        if numel(self.vars) > 0:
            emitln()

        for stmt in self.stmts[:-1]:
            stmt.emit()
            emitln()
        if numel(self.stmts) > 0:
            self.stmts[-1].emit()
        exits_scope()

    def declare(self, var: Var, *vars: List[Var | List[Var]]) -> Expr | List[Expr]:
        """Declare local variable(s) and return identifier(s) for them.

        A variable whose name is already declared is not added again; the
        identifier returned then refers to the existing declaration.
        Returns one `Ident` for a single argument, otherwise a list.
        """
        known = {v.name: v for v in self._flat_vars()}
        idents = []
        for v in (var, *vars):
            if v.name not in known:
                self.vars.append(v)
                known[v.name] = v
            idents.append(Ident(known[v.name]))
        if numel(vars) == 0:
            return idents[0]
        return idents

    def execute(self, stmt: Stmt | Expr, *args: List[Stmt | Expr]):
        """Append statements to the body, wrapping bare expressions.

        Raises TypeError for anything that is neither a `Stmt` nor an `Expr`.
        """
        for n in (stmt, *args):
            if isinstance(n, Stmt):
                self.stmts.append(n)
            elif isinstance(n, Expr):
                self.stmts.append(StmtExpr(n))
            else:
                raise TypeError("unrecognized type for function")

    def variables(self, *idents: List[str]) -> List[Expr]:
        """Look up locals/parameters by name.

        Returns one `Ident` for a single name, otherwise a list in argument
        order. Raises KeyError for an unknown name.
        """
        vars = {v.name: v for v in self._flat_vars() + self.params}
        if numel(idents) == 1:
            return Ident(vars[idents[0]])
        return [Ident(vars[ident]) for ident in idents]


class Var(Decl):
    def __init__(self, type: Type, name: str, storage: Mem = Mem.Auto):
        self.name = name
        self.type = type
        self.storage = storage

    def emit(self):
        emit(f"{self.name}")

    def emittype(self):
        """Emit the storage class (if any) followed by the type prefix."""
        if self.storage != Mem.Auto:
            emit(self.storage.value)
            emit(" ")
        self.type.emit()

    def emitspec(self):
        """Emit the declarator for this variable."""
        self.type.emitspec(self.name)


class Param(Var):
    def __init__(self, type: Type, name: str):
        return super(Param, self).__init__(type, name, Mem.Auto)


def Params(*ps: List[Tuple(Type, str)]) -> List[Param]:
    """Build `Param`s from `(type, name)` pairs."""
    return [Param(p[0], p[1]) for p in ps]


def Vars(*ps: List[Tuple(Type, str)]) -> List[Var]:
    """Build `Var`s from `(type, name)` pairs."""
    return [Var(p[0], p[1]) for p in ps]


# ------------------------------------------------------------------------
# AST modification/production functions

# ------------------------------------------
# basic (non-recursive) commands

def Swap(x: Var, y: Var, tmp: Var) -> Stmt:
    """Return the statement `tmp = x, x = y, y = tmp;`."""
    return StmtExpr(ExprList(Set(tmp, x), Set(x, y), Set(y, tmp)))


def IsLoop(s: Stmt) -> bool:
    return type(s) == For


def Expand(s: StmtExpr, times: int) -> Block:
    """Return a block holding `times` copies of `s`, the k-th one stepped
    by k (see `Step`). Raises TypeError unless `s` is a `StmtExpr`."""
    if not isinstance(s, StmtExpr):
        raise TypeError(f"{type(s)} not supported by Expand operation")
    return Block(*[StmtExpr(Step(s.x, i)) for i in range(times)])


def EvenTo(x: Var, n: int) -> Var:
    """Return the expression `x & ~(n-1)`."""
    return And(x, Negate(I(n - 1)))


@singledispatch
def Repeat(x: Expr, times: int) -> Expr | List[Expr]:
    """Return one expression equivalent to applying step `x` `times` times."""
    raise TypeError(f"{type(x)} not supported by Repeat operation")


@Repeat.register(Inc)
def _(x, times: int):
    return AddSet(x.x, I(times))


@Repeat.register(Dec)
def _(x, times: int):
    return SubSet(x.x, I(times))


@Repeat.register(Comma)
def _(x, times: int):
    return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times))


# TODO: Parameterize the variables that should not be expanded
@singledispatch
def Step(x: Expr, i: int) -> Expr:
    """Return a copy of `x` in which every indexed/dereferenced access is
    offset by `i`; identifiers and literals are returned unchanged."""
    raise TypeError(f"{type(x)} not supported by Step operation")


@Step.register(Comma)
def _(x, i: int):
    return Comma(Step(x.expr[0], i), Step(x.expr[1], i))


@Step.register(Assign)
def _(x, i: int):
    return type(x)(Step(x.lhs, i), Step(x.rhs, i))


@Step.register(BinaryOp)
def _(x, i: int):
    return type(x)(Step(x.l, i), Step(x.r, i))


@Step.register(UnaryOp)
def _(x, i: int):
    return _rebuild_unary(x, Step(x.x, i))


@Step.register(Ident)
@Step.register(Literal)
def _(x, i: int):
    return x


@Step.register(Deref)
def _(x, i: int):
    # *p at step i becomes p[i]
    return Index(x.x, I(i))


@Step.register(Index)
def _(ix, i: int):
    return Index(ix.x, Add(ix.i, I(i)))


# ------------------------------------------
# applies func to the direct children of a statement (one level only)

@singledispatch
def Visit(s: Stmt, func):
    raise TypeError(f"{type(s)} not supported by Visit operation")


@Visit.register(Empty)
def _(s, func):
    return


@Visit.register(Block)
def _(blk, func):
    for stmt in blk.stmts:
        func(stmt)


@Visit.register(If)
def _(jmp, func):
    func(jmp.cond)
    func(jmp.then)
    if jmp.orelse is not None:
        func(jmp.orelse)


@Visit.register(For)
def _(loop, func):
    func(loop.init)
    func(loop.cond)
    func(loop.step)
    func(loop.body)


@Visit.register(Return)
def _(ret, func):
    func(ret.val)


@Visit.register(StmtExpr)
def _(x, func):
    func(x.x)


# ------------------------------------------
# recreates a piece of an AST, allowing an arbitrary transformation of expression nodes

@singledispatch
def Make(s: Stmt, func):
    raise TypeError(f"{type(s)} not supported by Make operation")


@Make.register(Empty)
def _(s, func):
    return Empty()


@Make.register(Block)
def _(blk, func):
    return Block(*[func(stmt) for stmt in blk.stmts])


@Make.register(For)
def _(loop, func):
    return For(func(loop.init), func(loop.cond), func(loop.step), func(loop.body))


@Make.register(If)
def _(jmp, func):
    return If(func(jmp.cond), func(jmp.then), func(jmp.orelse) if jmp.orelse is not None else None)


@Make.register(Return)
def _(ret, func):
    return Return(func(ret.val))


@Make.register(StmtExpr)
def _(x, func):
    return StmtExpr(func(x.x))


# ------------------------------------------
# GetType function

@singledispatch
def GetType(x: Expr) -> Type:
    """Return the `Type` of a node; None for None, TypeError if unsupported."""
    if x is None:
        return
    raise TypeError(f"{type(x)} not supported by GetType operation")


@GetType.register(Empty)
def _(x):
    return Void


@GetType.register(S)
def _(x):
    return Ptr(Byte)


@GetType.register(I)
def _(x):
    return Int


@GetType.register(F)
def _(x):
    return Float64


@GetType.register(Ident)
def _(x):
    return x.var.type


@GetType.register(UnaryOp)
def _(x):
    return GetType(x.x)


@GetType.register(Deref)
def _(x):
    return GetType(x.x).to


@GetType.register(Ref)
def _(x):
    return Ptr(GetType(x.x))


@GetType.register(Index)
def _(x):
    base = GetType(x.x)
    if isinstance(base, Ptr):
        return base.to
    elif isinstance(base, Array):
        return base.base
    else:
        raise TypeError(f"attempting to index type {base}")


# TODO: type checking for both
@GetType.register(BinaryOp)
def _(x):
    # Both sides are resolved so an unsupported operand raises; the result
    # takes the left operand's type.
    lhs = GetType(x.l)
    GetType(x.r)
    return lhs


@GetType.register(Var)
def _(x):
    return x.type


@GetType.register(Assign)
def _(x):
    lhs = GetType(x.lhs)
    GetType(x.rhs)
    return lhs


@GetType.register(Comma)
def _(x):
    return GetType(x.expr[1])


@GetType.register(Paren)
def _(x):
    return GetType(x.x)


@GetType.register(Call)
def _(x):
    return x.func.ret


# ------------------------------------------
# Transform function
# Recurses down AST, and modifies Expr nodes
# Returns a brand new tree structure

@singledispatch
def Transform(x: Expr, func):
    """Rebuild `x` bottom-up, passing every rebuilt node through `func`."""
    raise TypeError(f"{type(x)} not supported by Transform operation")


@Transform.register(Empty)
def _(x, func):
    return func(Empty())


@Transform.register(Literal)
def _(x, func):
    return func(type(x)(x))


@Transform.register(Ident)
def _(x, func):
    return func(Ident(x.var))


@Transform.register(UnaryOp)
def _(x, func):
    # Rebuild with the node's own class so the operator is preserved.
    return func(_rebuild_unary(x, Transform(x.x, func)))


@Transform.register(Assign)
def _(x, func):
    return func(type(x)(Transform(x.lhs, func), Transform(x.rhs, func)))


@Transform.register(Deref)
def _(x, func):
    return func(Deref(Transform(x.x, func)))


@Transform.register(Index)
def _(x, func):
    return func(Index(Transform(x.x, func), Transform(x.i, func)))


@Transform.register(Comma)
def _(x, func):
    return func(Comma(Transform(x.expr[0], func), Transform(x.expr[1], func)))


@Transform.register(BinaryOp)
def _(x, func):
    return func(type(x)(Transform(x.l, func), Transform(x.r, func)))


# ------------------------------------------
# Filter function
# Recurses down AST, and stores Expr nodes that satisfy condition in results

@singledispatch
def Filter(x: Expr, cond, results: List[Expr]):
    raise TypeError(f"{type(x)} not supported by Filter operation")


@Filter.register(Empty)
def _(x, cond, results: List[Expr]):
    if cond(x):
        results.append(x)
def _filter_node(node, cond, results, *children):
    """Record `node` if it satisfies `cond`, then descend into `children`
    in the order given (pre-order traversal)."""
    if cond(node):
        results.append(node)
    for child in children:
        Filter(child, cond, results)


@Filter.register(Ident)
@Filter.register(Literal)
def _(x, cond, results: List[Expr]):
    _filter_node(x, cond, results)


@Filter.register(UnaryOp)
def _(x, cond, results: List[Expr]):
    _filter_node(x, cond, results, x.x)


@Filter.register(Assign)
def _(s, cond, results: List[Expr]):
    _filter_node(s, cond, results, s.lhs, s.rhs)


@Filter.register(Deref)
def _(v, cond, results: List[Expr]):
    _filter_node(v, cond, results, v.x)


@Filter.register(Index)
def _(i, cond, results: List[Expr]):
    _filter_node(i, cond, results, i.x, i.i)


@Filter.register(Comma)
def _(comma, cond, results: List[Expr]):
    _filter_node(comma, cond, results, comma.expr[0], comma.expr[1])


@Filter.register(BinaryOp)
def _(op, cond, results: List[Expr]):
    _filter_node(op, cond, results, op.l, op.r)


# ------------------------------------------
# Eval function
# Recurses down AST, and evaluates Expr nodes
# Throws an error if tree can't be evaluated at compile time

@singledispatch
def Eval(x: Expr) -> object:
    """Evaluate `x` symbolically. Identifiers and strings become sympy
    symbols; numeric literals evaluate to themselves.

    Raises TypeError for node kinds that have no registered evaluation.
    """
    raise TypeError(f"{type(x)} not supported by Eval operation")


@Eval.register(Empty)
def _(x):
    pass


@Eval.register(Ident)
@Eval.register(S)
def _(x):
    return sympy.symbols(f"{x}")


@Eval.register(I)
@Eval.register(F)
def _(x):
    return x


@Eval.register(Inc)
def _(x):
    # Value of the operand after the increment
    return sympy.simplify(Eval(x.x) + 1)


@Eval.register(Dec)
def _(x):
    # Value of the operand after the decrement
    return sympy.simplify(Eval(x.x) - 1)


# TODO: This won't work in general (if we have things like sizeof(x) + 1 - sizeof(x))
@Eval.register(Add)
def _(op):
    l = Eval(op.l)
    r = Eval(op.r)
    return sympy.simplify(l + r)


@Eval.register(Sub)
def _(op):
    l = Eval(op.l)
    r = Eval(op.r)
    return sympy.simplify(l - r)


@Eval.register(Assign)
def _(s):
    # An assignment evaluates to its right-hand side
    return Eval(s.rhs)


@Eval.register(Comma)
def _(comma):
    # A comma expression evaluates to its last operand
    return Eval(comma.expr[1])
# ------------------------------------------
# Leaf traversal

def VarsUsed(stmt: Stmt) -> List[Var]:
    """Return the distinct variables referenced by the direct children of
    `stmt`, in order of first use.

    Variables are de-duplicated by object identity. The order is
    deterministic so that signatures generated from it are stable.
    """
    idents = []
    Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), idents))
    return list(dict.fromkeys(ident.var for ident in idents))


def AddrAccessed(stmt: Stmt):
    """Classify the identifiers used by `stmt`.

    Returns `(scalars, vectors)`: `scalars` is the set of identifiers left
    over after removing one occurrence per index expression; `vectors` maps
    each indexed identifier to the index expression of its last access.

    Raises ValueError if the indexed operand of an `Index` does not contain
    exactly one identifier.
    """
    scalars = []   # variables that are accessed as scalars (single sites)
    vectors = {}   # variables that are accessed as vectors (indexed/dereferenced)
    nodes = []
    Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), scalars))
    Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Index), nodes))

    for node in nodes:
        vars = []
        Filter(node.x, lambda x: isinstance(x, Ident), vars)
        if numel(vars) != 1:
            raise ValueError("multiple variables used in index expression not supported")
        base = vars[0]
        vectors[base] = node.i
        # Each indexed use also showed up in the plain identifier scan above;
        # drop one occurrence per index node.
        if base in scalars:
            scalars.remove(base)

    return set(scalars), vectors


# ------------------------------------------
# Large scale functions

def Unroll(loop: For, times: int, name: str, vars: List[Var] = None) -> Tuple[For, Func, Call]:
    """Split `loop` into an unrolled kernel function plus a remainder loop.

    Returns `(loop, kernel, call)`:
      * `kernel` is a new `Func` named `{name}_kernel{times}` whose body
        masks the bound with `EvenTo`, runs the loop body `times` times per
        iteration, and returns the masked bound;
      * `call` is a `Call` of that kernel with the kernel's parameters;
      * `loop` is the input loop with its `init` cleared (mutated in place).

    `vars` is accepted for compatibility and is not used.

    Raises TypeError if the loop condition is not `<`/`<=`, or if its
    right-hand side is not a `Param`.
    """
    # TODO: More sophisticated computation for length of loop
    if not isinstance(loop.cond, (LE, LT)):
        raise TypeError(f"{type(loop.cond)} not supported in loop unrolling")

    # pull off needed features of the loop
    counter = loop.init.lhs.var
    used = VarsUsed(loop.body)
    params = [v for v in used if type(v) == Param]
    stacks = [v for v in used if type(v) == Var]

    # TODO: More sophisticated type checking
    n = loop.cond.r.var
    if type(n) != Param:
        raise TypeError(f"{type(n)} not implemented yet")

    # The bound goes first; do not list it twice if the body also uses it.
    params = [n] + [p for p in params if p is not n]

    # The kernel needs the loop counter even when the body never mentions it.
    if not any(v is counter for v in stacks + params):
        if isinstance(counter, Param):
            params.append(counter)
        else:
            stacks.append(counter)

    kernel = Func(f"{name}_kernel{times}", Int, params, stacks)
    i = kernel.variables(counter.name)

    # NOTE(review): EvenTo masks with ~(times-1), which rounds down to a
    # multiple of `times` only when `times` is a power of two - confirm
    # callers respect that.
    kernel.execute(
        Set(n, EvenTo(n, times)),
        For(Set(i, I(0)), LT(i, n), Repeat(loop.step, times),
            body=Expand(loop.body, times)),
        Return(n),
    )

    loop.init = None
    return loop, kernel, Call(kernel, params)


# Replaces all vectorizable loops inside Func with vectorized variants
# Returns the new vectorized function (tagged by the SIMD chosen)
def Vectorize(func: Func, isa: SIMD) -> Func:
    """Build a vectorized copy of `func` named `{func.name}_{isa.value}`.

    Only `SIMD.AVX2` is accepted. Non-loop statements are copied as is.
    A loop whose body is not a `Block` is copied unvectorized after printing
    a notice. Otherwise the body's statements are examined in groups of four
    and the first statement of each group is rewritten on vector variables.

    Raises ValueError if the ISA is unsupported, the statement count is not
    a multiple of four, the scalars used differ within a group, or indexed
    accesses in consecutive statements do not differ by exactly one.
    """
    if isa != SIMD.AVX2:
        raise ValueError(f"ISA '{isa}' not currently implemented")

    # TODO: Remove hardcoded 4 -> should be function of types!
    # As of now, we have harcoded AVX2 <-> float64 relationship
    lanes = 4

    # Carry over the source function's locals (copied, so declaring new
    # vector variables does not touch `func`).
    vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params, list(func.vars))

    for stmt in func.stmts:
        if not IsLoop(stmt):
            vfunc.execute(stmt)
            continue

        loop = For(stmt.init, stmt.cond, stmt.step)
        iterator = {stmt.init.lhs}
        body = stmt.body

        if type(body) != Block:
            print("could not vectorize loop, skipping")
            loop.body = body
            vfunc.execute(loop)
            continue

        instr = body.stmts
        if numel(instr) % lanes != 0:
            raise ValueError("loop can not be vectorized, instructions can not be globbed equally")

        # TODO: Allow for non-sequential accesses?
        s_symtab, v_symtab = {}, {}

        # IMPORTANT: We do a post-order traversal.
        # We transforms leaves (identifiers) first and then move back up the root
        def translate(x):
            if isinstance(x, Ident):
                if x.name in s_symtab:
                    return s_symtab[x.name]
                elif x.name in v_symtab:
                    # Retype the identifier in place; its name is unchanged.
                    x.var = v_symtab[x.name]

            if isinstance(x, Index):
                if type(x.x) == Ident and x.x.name in v_symtab:
                    # Indexing a vector pointer becomes pointer arithmetic.
                    return Add(x.x, x.i)

            if isinstance(x, BinaryOp):
                l, r = GetType(x.l), GetType(x.r)
                if IsVectorType(l) or IsVectorType(r):
                    if type(l) == Ptr:
                        x.l = Deref(x.l)
                    if type(r) == Ptr:
                        x.r = Deref(x.r)

            return x

        for i in range(0, numel(instr), lanes):
            scalars = []
            vectors = []
            for j in range(lanes):
                s, v = AddrAccessed(instr[i + j])
                s -= iterator   # TODO: This is hacky
                scalars.append(frozenset(s))
                vectors.append(v)

            # Test if code in uniform to allow for vectorization
            if numel(set(scalars)) != 1:
                raise ValueError("non uniform scalar accesses in consecutive line. can not vectorize")

            for j in range(1, lanes):
                for v, idx in vectors[j].items():
                    if (delta := (Eval(Sub(idx, vectors[j - 1][v])))) != 1:
                        print(f"{delta}")
                        raise ValueError("non uniform vector accesses in consecutive line. can not vectorize")

            # If we made it to here, we have passed all checks. vectorize!
            # Create symbol table (once, from the first group)
            if i == 0:
                # Sorted by name so declarations come out in a stable order.
                for s in sorted(scalars[0], key=lambda ident: ident.name):
                    if type(s.var) == Param:
                        # Scalar parameters are loaded into a 128-bit
                        # register and then broadcast to 256 bits.
                        intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128"))
                        s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
                        vfunc.execute(Set(intermediate, Ref(s)))
                        vfunc.execute(Set(s_symtab[s.name], intermediate))
                    else:
                        s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))

                for v in vectors[0]:
                    v_symtab[v.name] = Var(Ptr(Float64x4), f"{v.name}256")

            loop.execute(Make(instr[i], lambda node: Transform(node, translate)))

        vfunc.execute(loop)

    return vfunc


def Strided(func: Func) -> Func:
    pass


# ------------------------------------------------------------------------
# Point of testing

def axpby():
    """Generate, unroll and vectorize `y[i] = b * (y[i] + a * x[i])`, then
    print the accumulated output buffer."""
    F = Func("blas·axpby", Void,
             Params(
                 (Int, "len"), (Float64, "a"), (Ptr(Float64), "x"),
                 (Float64, "b"), (Ptr(Float64), "y"),
             ),
             Vars(
                 (Int, "i"),
                 # (Float64, "tmp")
             ))
    # could also declare like
    # F.declare( ... )

    # TODO: Increase ergonomics here...
    len, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b")

    loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
    # body = Swap(Index(x, I(0)), Index(y, I(0)), tmp)
    loop.execute(
        Set(Index(y, i),
            Mul(b,
                Add(Index(y, i),
                    Mul(a, Index(x, i)))))
    )

    rem, kernel, call = Unroll(loop, 16, "axpy")

    kernel.emit()
    emitln(2)

    avx256kernel = Vectorize(kernel, SIMD.AVX2)
    avx256kernel.emit()
    emitln(2)

    F.execute(Set(i, call), rem)
    F.emit()

    print(buffer)


def argmax():
    """Generate the `blas·argmax` loop and print the accumulated output
    buffer. Kept in a function so that its locals (`F`, `len`, `max`) do not
    rebind the module-level names of the same spelling."""
    F = Func("blas·argmax", Int,
             Params(
                 (Int, "len"), (Ptr(Float64), "x"),
             ),
             Vars(
                 (Int, "i"), (Int, "ix"), (Float64, "max"),
             ))
    len, x, i, ix, max = F.variables("len", "x", "i", "ix", "max")

    loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
    loop.execute(
        If(GT(Index(x, i), max),
           Block(
               Set(ix, i),
               Set(max, Index(x, i)),
           ))
    )

    F.execute(loop)
    F.emit()
    print(buffer)


if __name__ == "__main__":
    emitheader()
    argmax()