diff options
Diffstat (limited to 'lib')
-rw-r--r-- | lib/c.py | 1579 |
1 files changed, 0 insertions, 1579 deletions
diff --git a/lib/c.py b/lib/c.py deleted file mode 100644 index bbc9ce3..0000000 --- a/lib/c.py +++ /dev/null @@ -1,1579 +0,0 @@ -#! /bin/python3 -from __future__ import annotations - -import os -import sys - -import sympy - -from enum import Enum -from typing import Callable, List, Tuple, Dict -from functools import singledispatch - -numel = len - -# ------------------------------------------------------------------------ -# String buffer - -level = 0 -buffer = "" - -def emit(s: str): - global buffer - buffer += s - -def emitln(n=1): - global buffer - buffer += "\n"*n + (" " * level) - -def enter_scope(): - global level - emit("{") - level += 1 - emitln() - -def exits_scope(): - global level - level -= 1 - emitln() - emit("}") - -def emitheader(): - emit("#include <u.h>\n") - emit("#include <libn.h>\n") - emit("#include <libmath.h>\n") - emitln() - -# ------------------------------------------------------------------------ -# Simple C AST -# TODO: Type checking - -# Abstract class everything will derive from -# All AST nodes will have an "emit" function that outputs formatted C code -class Emitter(object): - def emit(self): - pass - -# ------------------------------------------ -# Representation of a C type - -class Type(Emitter): - def emit(self): - pass - - def emitspec(self, var): - pass - -class Base(Type): - def __init__(self, name: str): - self.name = name - - def __hash__(self): - return hash(self.name) - - def __eq__(self, other): - return type(self) == type(other) and self.name == other.name - - def __str__(self): - return self.name - - def emit(self): - emit(self.name) - - def emitspec(self, ident): - emit(f"{ident}") - -# TODO: Operation lookup tables... - -class Ptr(Type): - def __init__(self, to: Type): - self.to = to - - def __str__(self): - return f"*{self.to}" - - def __eq__(self, other): - return type(self) == type(other) and self.to == other.to - - def __hash__(self): - return hash((hash(self.to), "*")) - - def emit(self): - self.to.emit() - - def emitspec(self, ident): - emit("*") - self.to.emitspec(ident) - -class Array(Type): - def __init__(self, base: Type, len: int): - self.base = base - self.len = len - - def __str__(self): - return f"[{self.len}]{self.base}" - - def __eq__(self, other): - return type(self) == type(other) and self.base == other.base - - def __hash__(self): - return hash((hash(self.base), self.len)) - - def emit(self): - self.base.emit() - - def emitspec(self, ident): - self.base.emitspec(ident) - emit(f"[{self.len}]") - -# Machine primitive types -Void = Base("void") -Error = Base("error") - -Byte = Base("byte") -String = Ptr(Byte) - -Int = Base("int") -Int8 = Base("int8") -Int16 = Base("int16") -Int32 = Base("int32") -Int64 = Base("int64") -Int32x4 = Base("__mm128i") -Int32x8 = Base("__mm256i") -Int64x2 = Base("__mm128i") -Int64x4 = Base("__mm256i") - -Float32 = Base("float") -Float64 = Base("double") -Float32x4 = Base("__m128") -Float32x8 = Base("__mm256") -Float64x2 = Base("__m128d") -Float64x4 = Base("__mm256d") - -def IsVectorType(kind): - if kind is Float32x4 or \ - kind is Float32x8 or \ - kind is Float64x2 or \ - kind is Float64x4 or \ - kind is Int32x4 or \ - kind is Int32x8 or \ - kind is Int64x2 or \ - kind is Int64x4: - return True - - if type(kind) == Ptr: - return IsVectorType(kind.to) - - return False - -def IsArrayType(kind): - return isinstance(kind, Array) - -# TODO: Make this ARCH dependent -BitDepth= { - Void: 8, - Int: 32, - Int32: 32, - Int64: 64, - Float32: 64, - Float64: 64, -} - -class SIMD(Enum): - SSE = "sse" - SSE2 = "sse2" - AVX = "avx" - AVX2 = "avx2" - FMA3 = "fma3" - AVX5 = "avx512" - -RegisterSize = { - SIMD.SSE: 128, - SIMD.SSE2: 128, - SIMD.AVX: 256, - SIMD.AVX2: 256, - SIMD.FMA3: 256, - SIMD.AVX5: 512, -} - -SIMDSupport = { - SIMD.SSE: set([Float32]), - SIMD.SSE2: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]), - SIMD.AVX: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]), - SIMD.AVX2: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]), - SIMD.FMA3: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]), - SIMD.AVX5: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]), -} - -# TODO: Think of a better way to handle this -def emit·load128(l, r): - if l is not None: - l.emit() - emit(" = ") - emit("_mm_loadu_sd(") - r.emit() - emit(")") - -def emit·broadcast256(l, r): - l.emit() - emit(" = _mm256_broadcastsd_pd(") - r.emit() - emit(")") - -def emit·load256(l, r): - if l is not None: - l.emit() - emit(" = ") - emit("_mm256_loadu_pd(") - r.emit() - emit(")") - -def emit·store256(l, r): - emit("_mm256_storeu_pd(") - l.emit() - emit(", ") - r.emit() - emit(")") - -def emit·copy256(l, r): - emit("_mm256_storeu_pd(") - l.emit() - emit(", ") - emit("_mm256_loadu_pd(") - r.emit() - emit("))") - -# TODO: Typedefs... - -# ------------------------------------------ -# C expressions -class Expr(Emitter): - def emit(): - pass - -# Literals -class Literal(Expr): - def emit(self): - emit(f"{self}") - -class I(Literal, int): - def __new__(cls, i: int): - return super(I, cls).__new__(cls, i) - -class F(Literal, float): - def __new__(self, f: float): - return super(F, self).__new__(cls, f) - -class S(Literal, str): - def __new__(self, s: str): - return super(S, self).__new__(cls, s) - -# Ident of symbol -class Ident(Expr): - def __init__(self, var): - self.name = var.name - self.var = var - - def __str__(self): - return str(self.name) - - def __hash__(self): - return hash(self.name) - - def __eq__(self, other): - return type(self) == type(other) and self.name == other.name - - def emit(self): - emit(f"{self.name}") - - -# Unary operators -class UnaryOp(Expr): - def __init__(self, x: Expr): - self.x = x - - def emit(self): - pass - -class Deref(UnaryOp): - method = { - Ptr(Float64x4) : lambda x: emit·load256(None, x), - } - - def emit(self): - kind = GetType(self.x) - if kind in self.method: - self.method[kind](self.x) - else: - emit("*") - self.x.emit() - -class Negate(UnaryOp): - def emit(self): - emit("~") - self.x.emit() - -class Ref(UnaryOp): - def emit(self): - emit("&") - self.x.emit() - -class Inc(UnaryOp): - def __init__(self, x: Expr, pre=False): - self.x = x - self.pre = pre - - def emit(self): - if self.pre: - emit("++") - self.x.emit() - else: - self.x.emit() - emit("++") - -class Dec(UnaryOp): - def __init__(self, x: Expr, pre=False): - self.x = x - self.pre = pre - - def emit(self): - if self.pre: - emit("--") - self.x.emit() - else: - self.x.emit() - emit("--") - -# Binary operators -class BinaryOp(Expr): - def __init__(self, left: Expr, right: Expr): - self.l = left - self.r = right - - def emit(self): - pass - -# TODO: check types if they are vectorized and emit correct intrinsic -class Add(BinaryOp): - def emit(self): - self.l.emit() - emit(f" + ") - self.r.emit() - -class Sub(BinaryOp): - def emit(self): - self.l.emit() - emit(f" - ") - self.r.emit() - -class Mul(BinaryOp): - def emit(self): - self.l.emit() - emit(f" * ") - self.r.emit() - -class Div(BinaryOp): - def emit(self): - self.l.emit() - emit(f" / ") - self.r.emit() - -class And(BinaryOp): - def emit(self): - self.l.emit() - emit(f" & ") - self.r.emit() - -class Xor(BinaryOp): - def emit(self): - self.l.emit() - emit(f" ^ ") - self.r.emit() - -class GT(BinaryOp): - def emit(self): - self.l.emit() - emit(f" > ") - self.r.emit() - -class LT(BinaryOp): - def emit(self): - self.l.emit() - emit(f" < ") - self.r.emit() - -class GE(BinaryOp): - def emit(self): - self.l.emit() - emit(f" >= ") - self.r.emit() - -class LE(BinaryOp): - def emit(self): - self.l.emit() - emit(f" <= ") - self.r.emit() - -class EQ(BinaryOp): - def emit(self): - self.l.emit() - emit(f" == ") - self.r.emit() - -# Loads a scalar -class SIMDLoad(BinaryOp): - def emit(self): - self.l.emit() - self.r.emit() - -# Loads data at address -class SIMDLoadAt(BinaryOp): - def emit(self): - self.l.emit() - emit(" = _mm256_loadu_pd(") - self.r.emit() - emit(")") - -# Assignment (stores) -class Assign(Expr): - def __init__(self, lhs: Expr, rhs: Expr): - self.lhs = lhs - self.rhs = rhs - -class Set(Assign): - method = { - (Float64x2, Ptr(Float64)) : emit·load128, - (Float64x4, Float64x2) : emit·broadcast256, - (Float64x4, Ptr(Float64x4)) : emit·load256, - (Ptr(Float64x4), Float64x4) : emit·store256, - (Ptr(Float64x4), Ptr(Float64x4)) : emit·copy256, - } - def emit(self): - lhs = GetType(self.lhs) - rhs = GetType(self.rhs) - if (lhs, rhs) in self.method: - self.method[(lhs, rhs)](self.lhs, self.rhs) - else: - self.lhs.emit() - emit(f" = ") - self.rhs.emit() - -class AddSet(Assign): - def emit(self): - self.lhs.emit() - emit(f" += ") - self.rhs.emit() - -class SubSet(Assign): - def emit(self): - self.lhs.emit() - emit(f" -= ") - self.rhs.emit() - -class MulSet(Assign): - def emit(self): - self.lhs.emit() - emit(f" *= ") - self.rhs.emit() - -class DivSet(Assign): - def emit(self): - self.lhs.emit() - emit(f" /= ") - self.rhs.emit() - -class Comma(Expr): - def __init__(self, x: Expr, next: Expr): - self.expr = (x, next) - - def emit(self): - self.expr[0].emit() - emit(", ") - self.expr[1].emit() - -class Index(Expr): - def __init__(self, x: Expr, i: Expr): - self.x = x - self.i = i - - def emit(self): - self.x.emit() - emit("[") - self.i.emit() - emit("]") - -class Paren(Expr): - def __init__(self, x: Expr): - self.x = x - - def emit(self): - emit("(") - self.x.emit() - emit(")") - -class Call(Expr): - def __init__(self, func: Func, args: List[Param]): - self.func = func - self.args = args - - def emit(self): - emit(self.func.name) - emit("(") - if numel(self.args) > 0: - self.args[0].emit() - for arg in self.args[1:]: - emit(", ") - arg.emit() - emit(")") - -def ExprList(*expr: Expr | List[Expr]): - if numel(expr) > 1: - x = expr[0] - return Comma(x, ExprList(*expr[1:])) - else: - return expr[0] - -# Common assignments are kept globally -# C statements - -class Stmt(Emitter): - def emit(self): - emit(";") - -class Empty(Stmt): - pass - -class Block(Stmt): - def __init__(self, *stmts: List[Stmt | Expr]): - self.stmts = [ s if isinstance(s, Stmt) else StmtExpr(s) for s in stmts ] - - def emit(self): - enter_scope() - for i, stmt in enumerate(self.stmts): - stmt.emit() - if i < numel(self.stmts) - 1: - emitln() - - exits_scope() - -class If(Stmt): - def __init__(self, cond: Expr | List[Expr], then: Stmt | List[Stmt], orelse: Stmt | List[Stmt] | None = None): - self.cond = cond if isinstance(cond, Expr) else ExprList(*cond) - self.then = Block(then) if isinstance(then, list) else then - self.orelse = Block(orelse) if isinstance(orelse, list) else orelse - - def emit(self): - emit("if (") - self.cond.emit() - emit(") ") - self.then.emit() - if self.orelse is not None: - self.orelse.emit() - -class For(Stmt): - def __init__(self, init: Expr | List[Expr], cond: Expr | List[Expr], step: Expr | List[Expr], body: Stmt| None = None): - self.init = init if isinstance(init, Expr) else ExprList(*init) - self.cond = cond if isinstance(cond, Expr) else ExprList(*cond) - self.step = step if isinstance(step, Expr) else ExprList(*step) - self.body = body if body is not None else Empty() - - def emit(self): - emit("for (") - if self.init is not None: - self.init.emit() - emit("; ") - if self.cond is not None: - self.cond.emit() - emit("; ") - if self.step is not None: - self.step.emit() - emit(") ") - - if isinstance(self.body, Block): - self.body.emit() - else: - enter_scope() - self.body.emit() - exits_scope() - - def execute(self, stmt: Stmt | Expr): - if isinstance(stmt, Expr): - stmt = StmtExpr(stmt) - - if isinstance(self.body, Empty): - self.body = stmt - return - elif not isinstance(self.body, Block): - self.body = Block(self.body) - self.body.stmts.append(stmt) - -class Return(Stmt): - def __init__(self, val: Expr): - self.val = val - - def emit(self): - emitln() - emit("return ") - self.val.emit() - super(Return, self).emit() - -class StmtExpr(Stmt): - def __init__(self, x: Expr): - self.x = x - - def emit(self): - self.x.emit() - super(StmtExpr, self).emit() - -class Mem(Enum): - Auto = "" - Static = "static" - Register = "register" - Typedef = "typedef" - External = "extern" - -class Decl(Emitter): - def __init__(self): - pass - - def emit(self): - pass - -class Func(Decl): - def __init__(self, name: str, ret: Expr = Void, params: List[Param] = None, vars: List[Var | List[Var]] = None, body: List[Stmt] = None): - self.name = name - self.ret = ret - self.params = params if params is not None else [] - self.vars = vars if vars is not None else [] - self.stmts = body if body is not None else [] - - def emit(self): - self.ret.emit() - emitln() - emit(self.name) - emit("(") - for i, p in enumerate(self.params): - p.emittype() - emit(" ") - p.emitspec() - if i < numel(self.params) - 1: - emit(", ") - emit(")\n") - - enter_scope() - - for var in self.vars: - if isinstance(var, list): - v = var[0] - v.emittype() - emit(" ") - v.emitspec() - for v in var[1:]: - emit(", ") - v.emitspec() - else: - var.emittype() - emit(" ") - var.emitspec() - - emit(";") - emitln() - - if numel(self.vars) > 0: - emitln() - - for stmt in self.stmts[:-1]: - stmt.emit() - emitln() - if numel(self.stmts) > 0: - self.stmts[-1].emit() - - exits_scope() - emitln(2) - - def declare(self, var: Var, *vars: List[Var | List[Var]]) -> Expr | List[Expr]: - if var.name in [v.name for v in self.vars]: - return - - self.vars.append(var) - if numel(vars) == 0: - return Ident(var) - - self.vars.extend(vars) - - idents = [Ident(var)] - idents += [Ident(v) for v in vars] - return idents - - def execute(self, stmt: Stmt | Expr, *args: List[Stmt | Expr]): - def push(n): - if isinstance(n, Stmt): - self.stmts.append(n) - elif isinstance(n, Expr): - self.stmts.append(StmtExpr(n)) - else: - raise TypeError("unrecognized type for function") - push(stmt) - for arg in args: - push(arg) - - def variables(self, *idents: List[str]) -> List[Expr]: - vars = {v.name : v for v in self.vars + self.params} - - if numel(idents) == 1: - return Ident(vars[idents[0]]) - return [Ident(vars[ident]) for ident in idents] - -class Var(Decl): - def __init__(self, type: Type, name: str, storage: Mem = Mem.Auto): - self.name = name - self.type = type - self.storage = storage - - def emit(self): - emit(f"{self.name}") - - def emittype(self): - if self.storage != Mem.Auto: - emit(self.storage.value) - emit(" ") - self.type.emit() - - def emitspec(self): - self.type.emitspec(self.name) - - @classmethod - def copy(cls, other: Var): - return cls(other.type, other.name, other.storage) - -class Param(Var): - def __init__(self, type: Type, name: str): - return super(Param, self).__init__(type, name, Mem.Auto) - - @classmethod - def copy(cls, other: Param): - return cls(other.type, other.name) - -def Params(*ps: List[Tuple(Type, str)]) -> List[Param]: - return [Param(p[0], p[1]) for p in ps] - -def Vars(*ps: List[Tuple(Type, str)]) -> List[Var]: - return [Var(p[0], p[1]) for p in ps] - -# ------------------------------------------------------------------------ -# AST modification/production functions - -# ------------------------------------------ -# basic (non-recursive) commands - -def Swap(x: Var, y: Var, tmp: Var) -> Stmt: - return StmtExpr(ExprList(Set(tmp, x), Set(x, y), Set(y, tmp))) - -def IsLoop(s: Stmt) -> bool: - return type(s) == For - -def EvenTo(x: Var, n: int) -> Var: - return And(x, Negate(I(n-1))) - -# ------------------------------------------ -# Expand: takes statement, indexed, and expands it x times - -# def Expand(s: Stmt, times: int) -> Block: -# if not isinstance(s, StmtExpr): -# raise TypeError(f"{type(x)} not supported by Expand operation") - -# return Block(StmtExpr(Step(s.x, i)) for i in range(times)) - -# ------------------------------------------ -# repeat: takes command on an array and repeats it - -@singledispatch -def Repeat(x: Expr, times: int) -> Expr | List[Expr]: - raise TypeError(f"{type(x)} not supported by Repeat operation") - -@Repeat.register -def _(x: Inc, times: int): - return AddSet(x.x, I(times)) - -@Repeat.register -def _(x: Dec, times: int): - return DecSet(x.x, I(times)) - -@Repeat.register -def _(x: Comma, times: int): - return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times)) - -@singledispatch -def Repeat(x: Expr, times: int) -> Expr | List[Expr]: - raise TypeError(f"{type(x)} not supported by Repeat operation") - -@Repeat.register -def _(x: Inc, times: int): - return AddSet(x.x, I(times)) - -@Repeat.register -def _(x: Dec, times: int): - return DecSet(x.x, I(times)) - -@Repeat.register -def _(x: Comma, times: int): - return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times)) - -# ------------------------------------------ -# step: indexes an expression by i - -@singledispatch -def Step(x: Expr, i: int) -> Expr: - raise TypeError(f"{type(x)} not supported by Step operation") - -@Step.register -def _(x: Comma, i: int): - return Comma(Step(x.expr[0], i), Step(x.expr[1], i)) - -@Step.register -def _(x: Assign, i: int): - return type(x)(Step(x.lhs, i), Step(x.rhs, i)) - -@Step.register -def _(x: BinaryOp, i: int): - return type(x)(Step(x.l, i), Step(x.r, i)) - -@Step.register -def _(x: UnaryOp, i: int): - return type(x)(Step(x.x, i)) - -@Step.register -def _(x: Ident, i: int): - return x - -@Step.register -def _(x: Deref, i: int): - return Index(x.x, I(i)) - -@Step.register -def _(ix: Index, i: int): - return Index(ix.x, Add(ix.i, I(i))) - -# ------------------------------------------ -# bfs search on statements in ast - -@singledispatch -def Visit(s: Stmt, func): - raise TypeError(f"{type(s)} not supported by Visit operation") - -@Visit.register -def _(s: Empty, func): - return - -@Visit.register -def _(blk: Block, func): - for stmt in blk.stmts: - Visit(stmt, func) - -@Visit.register -def _(jmp: If, func): - func(jmp.cond) - Visit(jmp.then, func) - if jmp.orelse is not None: - Visit(jmp.orelse, func) - -@Visit.register -def _(loop: For, func): - func(loop.init) - func(loop.cond) - func(loop.step) - Visit(loop.body, func) - -@Visit.register -def _(ret: Return, func): - func(ret.val) - -@Visit.register -def _(x: StmtExpr, func): - func(x.x) - -# ------------------------------------------ -# recreates a piece of an AST, allowing an arbitrary transformation of expression nodes - -@singledispatch -def Make(s: Stmt, func): - raise TypeError(f"{type(s)} not supported by Make operation") - -@Make.register -def _(s: Empty, func): - return Empty() - -@Make.register -def _(blk: Block, func): - return Block(*(Make(stmt, func) for stmt in blk.stmts)) - -@Make.register -def _(loop: For, func): - return For(func(loop.init), func(loop.cond), func(loop.step), Make(loop.body, func)) - -@Make.register -def _(jmp: If, func): - if jmp.orelse is not None: - return If(func(jmp.cond), Make(jmp.then, func), Make(jmp.orelse, func)) - else: - return If(func(jmp.cond), Make(jmp.then, func)) - -@Make.register -def _(ret: Return, func): - return Return(func(ret.val)) - -@Make.register -def _(x: StmtExpr, func): - return StmtExpr(func(x.x)) - -# ------------------------------------------ -# GetType function - -@singledispatch -def GetType(x: Expr) -> Type: - if x is None: - return - raise TypeError(f"{type(x)} not supported by GetType operation") - -@GetType.register -def _(x: Empty): - return Void - -@GetType.register -def _(x: Empty): - return Void - -@GetType.register -def _(x: S): - return Ptr(Byte) - -@GetType.register -def _(x: I): - return Int - -@GetType.register -def _(x: F): - return Float64 - -@GetType.register -def _(x: Ident): - return x.var.type - -@GetType.register -def _(x: UnaryOp): - return GetType(x.x) - -@GetType.register -def _(x: Deref): - return GetType(x.x).to - -@GetType.register -def _(x: Ref): - return Ptr(GetType(x.x)) - -@GetType.register -def _(x: Index): - base = GetType(x.x) - if isinstance(base, Ptr): - return base.to - elif isinstance(base, Array): - return base.base - else: - pass - # raise TypeError(f"attempting to index type {base} of node {x.x}") - -# TODO: type checking for both - -@GetType.register -def _(x: BinaryOp): - lhs = GetType(x.l) - rhs = GetType(x.r) - return lhs - -@GetType.register -def _(x: Var): - return x.type - -@GetType.register -def _(x: Assign): - lhs = GetType(x.lhs) - rhs = GetType(x.rhs) - return lhs - -@GetType.register -def _(x: Comma): - return GetType(x.expr[1]) - -@GetType.register -def _(x: Paren): - return GetType(x.x) - -@GetType.register -def _(x: Call): - return x.func.ret - -# ------------------------------------------ -# Transform function -# Recurses down AST, and modifies Expr nodes -# Returns a brand new tree structure - -@singledispatch -def Transform(x: Expr, func: Callable[Expr, Expr]): - raise TypeError(f"{type(x)} not supported by Transform operation") - -@Transform.register -def _(x: Empty, func): - return func(Empty()) - -@Transform.register -def _(x: Literal, func): - return func(type(x)(x)) - -@Transform.register -def _(x: Ident, func): - return func(Ident(x.var)) - -@Transform.register -def _(x: UnaryOp, func): - return func(UnaryOp(Transform(x.x, func))) - -@Transform.register -def _(x: Assign, func): - return func(type(x)(Transform(x.lhs, func), Transform(x.rhs, func))) - -@Transform.register -def _(x: Deref, func): - return func(Transform(Deref(x.x, func))) - -@Transform.register -def _(x: Index, func): - return func(Index(Transform(x.x, func), Transform(x.i, func))) - -@Transform.register -def _(x: Comma, func): - return func(Comma(Transform(x.expr[0], func), Transform(x.expr[1], func))) - -@Transform.register -def _(x: BinaryOp, func): - return func(type(x)(Transform(x.l, func), Transform(x.r, func))) - -# ------------------------------------------ -# Filter function -# Recurses down AST, and stores Expr nodes that satisfy condition in results - -@singledispatch -def Filter(x: Expr, cond, results: List[Expr]): - raise TypeError(f"{type(x)} not supported by Filter operation") - -@Filter.register -def _(x: Empty, cond, results: List[Expr]): - if cond(x): - results.append(x) - -@Filter.register(Ident) -@Filter.register(Literal) -def _(x, cond, results: List[Expr]): - if cond(x): - results.append(x) - -@Filter.register -def _(x: UnaryOp, cond, results: List[Expr]): - if cond(x): - results.append(x) - Filter(op.x, cond, results) - -@Filter.register -def _(s: Assign, cond, results: List[Expr]): - if cond(s): - results.append(s) - Filter(s.lhs, cond, results) - Filter(s.rhs, cond, results) - -@Filter.register -def _(v: Deref, cond, results: List[Expr]): - if cond(v): - results.append(v) - Filter(v.x, cond, results) - -@Filter.register -def _(i: Index, cond, results: List[Expr]): - if cond(i): - results.append(i) - Filter(i.x, cond, results) - Filter(i.i, cond, results) - -@Filter.register -def _(comma: Comma, cond, results: List[Expr]): - if cond(comma): - results.append(comma) - Filter(comma.expr[0], cond, results) - Filter(comma.expr[1], cond, results) - -@Filter.register -def _(op: BinaryOp, cond, results: List[Expr]): - if cond(op): - results.append(op) - Filter(op.l, cond, results) - Filter(op.r, cond, results) - -# ------------------------------------------ -# Eval function -# Recurses down AST, and evaluates Expr nodes -# Throws an error if tree can't be evaluated at compile time - -@singledispatch -def Eval(x: Expr) -> object: - raise TypeError(f"{type(x)} not supported by Eval operation") - -@Eval.register -def _(x: Empty): - pass - -@Eval.register(Ident) -@Eval.register(S) -def _(x): - return sympy.symbols(f"{x}") - -@Eval.register(float) -@Eval.register(int) -@Eval.register(I) -@Eval.register(F) -def _(x): - return x - -@Eval.register -def _(x: Inc): - return Eval(Add(Eval(op.x, cond, results), I(1))) - -@Eval.register -def _(x: Dec): - return Eval(Sub(Eval(op.x, cond, results), I(1))) - -# TODO: This won't work in general (if we have things like sizeof(x) + 1 - sizeof(x)) -@Eval.register -def _(op: Add): - l = Eval(op.l) - r = Eval(op.r) - return sympy.simplify(l + r) - -@Eval.register -def _(op: Sub): - l = Eval(op.l) - r = Eval(op.r) - return sympy.simplify(l - r) - -@Eval.register -def _(op: Mul): - l = Eval(op.l) - r = Eval(op.r) - return sympy.simplify(l * r) - -@Eval.register -def _(op: Div): - l = Eval(op.l) - r = Eval(op.r) - return sympy.simplify(l / r) - -@Eval.register -def _(s: Assign): - return Eval(s.rhs) - -@Eval.register -def _(comma: Comma): - return Eval(comma.expr[1]) - -# ------------------------------------------ -# Leaf traversal - -def VarsUsed(stmt: Stmt) -> List[Var]: - vars = [] - Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), vars)) - vars = set([v.var for v in vars]) - - return vars - -def AddrAccessed(stmt: Stmt): - scalars = [] # variables that are accessed as scalars (single sites) - vectors = {} # variables that are accessed as vectors (indexed/dereferenced) - nodes = [] - Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), scalars)) - Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Index), nodes)) - - for node in nodes: - vars = [] - Filter(node.x, lambda x: isinstance(x, Ident), vars) - if numel(vars) != 1: - raise ValueError("multiple variables used in index expression not supported") - vectors[vars[0]] = node.i - if vars[0] in scalars: - scalars.remove(vars[0]) - - return set(scalars), vectors - -# ------------------------------------------ -# Large scale functions - -def Unroll(name: str, loop: For, times: int, ret: Ident = None, accumulator = None) -> (Func, List[Stmt]): - # TODO: More sophisticated computation for length of loop - if not isinstance(loop.cond, LE) and not isinstance(loop.cond, LT): - raise TypeError(f"{type(loop.cond)} not supported in loop unrolling") - - # pull off needed features of the loop - i = loop.init.lhs.var - vars = VarsUsed(loop.body) - - def asvector(v): - if v != i: - return Var(Array(v.type, times), v.name, v.storage) - else: - return Var.copy(v) - - n = loop.cond.r.var - param = [Param(Ptr(i.type), f"{i.name}p")] + [Param.copy(v) for v in vars if (type(v) == Param and v != n)] - stack = {v.name: asvector(v) for v in vars if type(v) == Var} - - # TODO: More sophisticated type checking - if (type(n) != Param): - raise TypeError(f"{type(n)} not implemented yet") - - if ret is None: - kernel = Func(f"{name}_kernel{times}", Void, param, list(stack.values())) - else: - kernel = Func(f"{name}_kernel{times}", ret.var.type, param, list(stack.values())) - - len = kernel.declare(n) - - body = loop.body - itor, itorp = kernel.variables("i", "ip") - - def mkarray(x: Expr, times: int): - if isinstance(x, Ident): - if type(x.var) == Var: - x.var = stack[x.name] - return x - - def step(x: Expr, times: int): - if x == itor: - return Add(x, I(times)) - if isinstance(x, Ident): - if type(x.var) == Var: - return Index(x, I(times)) - if not isinstance(x, Expr): - raise ValueError(f"panic, hit type {type(x)}") - - return x - - expandedloop = Make(loop.body, lambda node: Transform(node, lambda x: mkarray(x, times))) - kernel.execute( - Set(len, EvenTo(Deref(itorp), times)), - For(Set(itor, I(0)), LT(itor, len), Repeat(loop.step, times), - body = Block(* - (Make(expandedloop, lambda node: - Transform(node, lambda x: step(x, i))) for i in range(times) - ) - ) - ), - Set(Deref(itorp), itor) - ) - - if ret is not None: - k = kernel.variables(ret.name) - if k.var.type != ret.var.type: - if accumulator is None: - raise ValueError("If loop returns a value, an accumulator must be given") - else: - accumulator = For(Set(itor, I(1)), LT(itor, I(times)), Inc(itor), - body = accumulator(I(0), itor, k) - ) - kernel.execute(accumulator) - - kernel.execute(Return(Index(ret, I(0)))) - - loop.init = None - if ret is None: - return kernel, [Set(itor, len), Call(kernel, [Ref(itor)] + param[1:]), loop] - else: - return kernel, [Set(itor, len), Set(ret, Call(kernel, [Ref(itor)] + param[1:])), loop] - -# Replaces all vectorizable loops inside Func with vectorized variants -# Returns the new vectorized function (tagged by the SIMD chosen) - -class SymTab(object): - def __init__(self): - self.stack = {} - self.addrs = {} - - def have(self, name): - return name in self.stack or name in self.addrs - -def Vectorize(func: Func, isa: SIMD) -> Func: - if isa != SIMD.AVX2: - raise ValueError(f"ISA '{isa}' not currently implemented") - - vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params) - for stmt in func.stmts: - if IsLoop(stmt): - loop = For(stmt.init, stmt.cond, stmt.step) - iterator = set([stmt.init.lhs]) - - body = stmt.body - if type(body) != Block: - # TODO: Think through this carefully... - # This is coded for the accumulation step at the end of a kernel - loop.cond.r = I(Eval(Div(loop.cond.r, 4))) - loop.body = body - else: - instr = body.stmts - # TODO: Remove hardcoded 4 -> should be function of types! - # As of now, we have harcoded AVX2 <-> float64 relationship - if numel(instr) % 4 != 0: - raise ValueError("loop can not be vectorized, instructions can not be globbed equally") - - # TODO: Allow for non-sequential accesses? - for i in range(0, numel(instr), 4): - scalars = [] - vectors = [] - for j in range(4): - s, v = AddrAccessed(instr[i+j]) - s -= iterator # TODO: This is hacky - scalars.append(frozenset(s)) - vectors.append(v) - - # Test if code in uniform to allow for vectorization - if numel(set(scalars)) != 1: - raise ValueError("non uniform scalar accesses in consecutive line. can not vectorize") - - for j in range(1, 4): - for v, idx in vectors[j].items(): - if (delta := (Eval(Sub(idx, vectors[j-1][v])))) != 1: - print(f"{delta}") - raise ValueError("non uniform vector accesses in consecutive line. can not vectorize") - - # If we made it to here, we have passed all checks. vectorize! - vecs = vectors[0] - if i == 0: - syms = SymTab() - for s in scalars[0]: - intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128")) - syms.stack[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256")) - vfunc.execute(Set(intermediate, Ref(s))) - vfunc.execute(Set(syms.stack[s.name], intermediate)) - - for v in vecs.keys(): - # All params are treated AS addresses to load into vectorized registers - if type(v.var) == Param: - syms.addrs[v.name] = Var(Ptr(Float64x4), f"{v.name}256") - # All stack declared parameters MUST be moved to vectorized registers - else: - assert IsArrayType(v.var.type), f"must be an array type, instead got {v.var.type}" - nreg = v.var.type.len // 4 - if nreg > 1: - syms.stack[v.name] = vfunc.declare(Var(Array(Float64x4, nreg), f"{v.name}256")) - else: - syms.stack[v.name] = vfunc.declare(Var(Float64x4, f"{v.name}256")) - - - # IMPORTANT: We do a post-order traversal. - # We transforms leaves (identifiers) first and then move back up the root - def translate(x): - if isinstance(x, Ident): - if x.name in syms.stack: - return syms.stack[x.name] - elif x.name in syms.addrs: - x.var = syms.addrs[x.name] - if isinstance(x, Index): - if type(x.x) == Ident: - if x.x.name in syms.addrs: - return Add(x.x, x.i) - elif f"{x.x.name[:-3]}" in syms.stack: #NOTE: This is hacky. Think of something better - return Index(x.x, I(Eval(Div(x.i, 4)))) - if isinstance(x, BinaryOp): - l, r = GetType(x.l), GetType(x.r) - if IsVectorType(l) or IsVectorType(r): - if type(l) == Ptr: - x.l = Deref(x.l) - if type(r) == Ptr: - x.r = Deref(x.r) - - return x - - loop.execute(Make(instr[i], lambda node: Transform(node, translate))) - - vfunc.execute(loop) - else: - vfunc.execute(stmt) - - return vfunc - - -def Strided(func: Func) -> Func: - pass - -# ------------------------------------------------------------------------ -# Point of testing - -def copy(): - F = Func("blas·copy", Void, - Params( - (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y") - ), - Vars( - (Int, "i"), (Float64, "tmp") - ) - ) - # could also declare like - # F.declare( ... ) - - len, x, y, i, tmp = F.variables("len", "x", "y", "i", "tmp") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - Swap(Index(x, i), Index(y, i), tmp) - ) - - kernel, calls = Unroll("copy", loop, 8) - kernel.emit() - - avx256kernel = Vectorize(kernel, SIMD.AVX2) - avx256kernel.emit() - - F.execute(*calls) - F.emit() - -def axpby(): - F = Func("blas·axpby", Void, - Params( - (Int, "len"), (Float64, "a"), (Ptr(Float64), "x"), (Float64, "b"), (Ptr(Float64), "y") - ), - Vars( - (Int, "i"), #(Float64, "tmp") - ) - ) - # could also declare like - # F.declare( ... ) - - # TODO: Increase ergonomics here... - len, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - Set(Index(y, i), - Mul(b, - Add(Index(y, i), - Mul(a, Index(x, i)), - ) - ) - ) - ) - - kernel, calls = Unroll("axpby", loop, 8) - kernel.emit() - - avx256kernel = Vectorize(kernel, SIMD.AVX2) - avx256kernel.emit() - - F.execute(*calls) - F.emit() - -def argmax(): - F = Func("blas·argmax", Int, - Params( - (Int, "len"), (Ptr(Float64), "x"), - ), - Vars( - (Int, "i"), (Int, "ix"), (Float64, "max") - ) - ) - len, x, i, ix, max = F.variables("len", "x", "i", "ix", "max") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - If(GT(Index(x, i), max), - Block( - Set(ix, i), - Set(max, Index(x, ix)), - ) - ) - ) - kernel, calls = Unroll("argmax", loop, 8, ix, - lambda ires, icur, node: - If(GT(Index(x, Index(node, icur)), Index(x, Index(node, ires))), - Block(Set(Index(node, ires), Index(node, icur))) - ) - ) - kernel.emit() - - # avx256kernel = Vectorize(kernel, SIMD.AVX2) - # avx256kernel.emit() - - F.execute(*calls, Return(ix)) - - F.execute(loop) - F.emit() - -def dot(): - F = Func("blas·dot", Float64, - Params( - (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y") - ), - Vars( - (Int, "i"), (Float64, "sum"), - ) - ) - len, x, i, y, sum = F.variables("len", "x", "i", "y", "sum") - - loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) - loop.execute( - AddSet(sum, Mul(Index(x, i), Index(y, i))) - ) - - kernel, calls = Unroll("dot", loop, 16, sum, - lambda ires, icur, node: - StmtExpr(AddSet(Index(node, ires), Index(node, icur))) - ) - kernel.emit() - - avx256kernel = Vectorize(kernel, SIMD.AVX2) - avx256kernel.emit() - - F.execute(*calls, Return(sum)) - - F.emit() - -def gemv(): - F = Func("blas·gemv", Void, - Params( - (Int, "nrow"), (Int, "ncol"), (Float64, "a"), (Ptr(Float64), "m"), (Int, "incm"), - (Ptr(Float64), "x"), (Float64, "b"), (Ptr(Float64), "y") - ), - Vars( - (Int, "r"), (Int, "c"), (Ptr(Float64), "row"), (Float64, "res") - ) - ) - r, c, row, res = F.variables("r", "c", "row", "res") - nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y") - loop = For(Set(r, I(0)), LT(r, nrow), Inc(r), - Block( - Set(row, Add(m, incm)), - Set(res, I(0)), - For(Set(c, I(0)), LT(c, ncol), Inc(c), - AddSet(res, Mul(Index(row, c), Index(x, c))) - ), - Set(Index(y, r), Add(Mul(a, res), Mul(b, Index(y, r)))) - ) - ) - - F.execute(loop) - F.emit() - - -if __name__ == "__main__": - emitheader() - - gemv() - # dot() - # argmax() - # copy() - # axpby() - - print(buffer) |