aboutsummaryrefslogtreecommitdiff
path: root/lib/c.py
diff options
context:
space:
mode:
Diffstat (limited to 'lib/c.py')
-rw-r--r--lib/c.py1579
1 files changed, 0 insertions, 1579 deletions
diff --git a/lib/c.py b/lib/c.py
deleted file mode 100644
index bbc9ce3..0000000
--- a/lib/c.py
+++ /dev/null
@@ -1,1579 +0,0 @@
-#! /bin/python3
-from __future__ import annotations
-
-import os
-import sys
-
-import sympy
-
-from enum import Enum
-from typing import Callable, List, Tuple, Dict
-from functools import singledispatch
-
-numel = len
-
-# ------------------------------------------------------------------------
-# String buffer
-
-level = 0
-buffer = ""
-
-def emit(s: str):
- global buffer
- buffer += s
-
-def emitln(n=1):
- global buffer
- buffer += "\n"*n + (" " * level)
-
-def enter_scope():
- global level
- emit("{")
- level += 1
- emitln()
-
-def exits_scope():
- global level
- level -= 1
- emitln()
- emit("}")
-
-def emitheader():
- emit("#include <u.h>\n")
- emit("#include <libn.h>\n")
- emit("#include <libmath.h>\n")
- emitln()
-
-# ------------------------------------------------------------------------
-# Simple C AST
-# TODO: Type checking
-
-# Abstract class everything will derive from
-# All AST nodes will have an "emit" function that outputs formatted C code
-class Emitter(object):
- def emit(self):
- pass
-
-# ------------------------------------------
-# Representation of a C type
-
-class Type(Emitter):
- def emit(self):
- pass
-
- def emitspec(self, var):
- pass
-
-class Base(Type):
- def __init__(self, name: str):
- self.name = name
-
- def __hash__(self):
- return hash(self.name)
-
- def __eq__(self, other):
- return type(self) == type(other) and self.name == other.name
-
- def __str__(self):
- return self.name
-
- def emit(self):
- emit(self.name)
-
- def emitspec(self, ident):
- emit(f"{ident}")
-
-# TODO: Operation lookup tables...
-
-class Ptr(Type):
- def __init__(self, to: Type):
- self.to = to
-
- def __str__(self):
- return f"*{self.to}"
-
- def __eq__(self, other):
- return type(self) == type(other) and self.to == other.to
-
- def __hash__(self):
- return hash((hash(self.to), "*"))
-
- def emit(self):
- self.to.emit()
-
- def emitspec(self, ident):
- emit("*")
- self.to.emitspec(ident)
-
-class Array(Type):
- def __init__(self, base: Type, len: int):
- self.base = base
- self.len = len
-
- def __str__(self):
- return f"[{self.len}]{self.base}"
-
- def __eq__(self, other):
- return type(self) == type(other) and self.base == other.base
-
- def __hash__(self):
- return hash((hash(self.base), self.len))
-
- def emit(self):
- self.base.emit()
-
- def emitspec(self, ident):
- self.base.emitspec(ident)
- emit(f"[{self.len}]")
-
-# Machine primitive types
-Void = Base("void")
-Error = Base("error")
-
-Byte = Base("byte")
-String = Ptr(Byte)
-
-Int = Base("int")
-Int8 = Base("int8")
-Int16 = Base("int16")
-Int32 = Base("int32")
-Int64 = Base("int64")
-Int32x4 = Base("__mm128i")
-Int32x8 = Base("__mm256i")
-Int64x2 = Base("__mm128i")
-Int64x4 = Base("__mm256i")
-
-Float32 = Base("float")
-Float64 = Base("double")
-Float32x4 = Base("__m128")
-Float32x8 = Base("__mm256")
-Float64x2 = Base("__m128d")
-Float64x4 = Base("__mm256d")
-
-def IsVectorType(kind):
- if kind is Float32x4 or \
- kind is Float32x8 or \
- kind is Float64x2 or \
- kind is Float64x4 or \
- kind is Int32x4 or \
- kind is Int32x8 or \
- kind is Int64x2 or \
- kind is Int64x4:
- return True
-
- if type(kind) == Ptr:
- return IsVectorType(kind.to)
-
- return False
-
-def IsArrayType(kind):
- return isinstance(kind, Array)
-
-# TODO: Make this ARCH dependent
-BitDepth= {
- Void: 8,
- Int: 32,
- Int32: 32,
- Int64: 64,
- Float32: 64,
- Float64: 64,
-}
-
-class SIMD(Enum):
- SSE = "sse"
- SSE2 = "sse2"
- AVX = "avx"
- AVX2 = "avx2"
- FMA3 = "fma3"
- AVX5 = "avx512"
-
-RegisterSize = {
- SIMD.SSE: 128,
- SIMD.SSE2: 128,
- SIMD.AVX: 256,
- SIMD.AVX2: 256,
- SIMD.FMA3: 256,
- SIMD.AVX5: 512,
-}
-
-SIMDSupport = {
- SIMD.SSE: set([Float32]),
- SIMD.SSE2: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
- SIMD.AVX: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
- SIMD.AVX2: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
- SIMD.FMA3: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
- SIMD.AVX5: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
-}
-
-# TODO: Think of a better way to handle this
-def emit·load128(l, r):
- if l is not None:
- l.emit()
- emit(" = ")
- emit("_mm_loadu_sd(")
- r.emit()
- emit(")")
-
-def emit·broadcast256(l, r):
- l.emit()
- emit(" = _mm256_broadcastsd_pd(")
- r.emit()
- emit(")")
-
-def emit·load256(l, r):
- if l is not None:
- l.emit()
- emit(" = ")
- emit("_mm256_loadu_pd(")
- r.emit()
- emit(")")
-
-def emit·store256(l, r):
- emit("_mm256_storeu_pd(")
- l.emit()
- emit(", ")
- r.emit()
- emit(")")
-
-def emit·copy256(l, r):
- emit("_mm256_storeu_pd(")
- l.emit()
- emit(", ")
- emit("_mm256_loadu_pd(")
- r.emit()
- emit("))")
-
-# TODO: Typedefs...
-
-# ------------------------------------------
-# C expressions
-class Expr(Emitter):
- def emit():
- pass
-
-# Literals
-class Literal(Expr):
- def emit(self):
- emit(f"{self}")
-
-class I(Literal, int):
- def __new__(cls, i: int):
- return super(I, cls).__new__(cls, i)
-
-class F(Literal, float):
- def __new__(self, f: float):
- return super(F, self).__new__(cls, f)
-
-class S(Literal, str):
- def __new__(self, s: str):
- return super(S, self).__new__(cls, s)
-
-# Ident of symbol
-class Ident(Expr):
- def __init__(self, var):
- self.name = var.name
- self.var = var
-
- def __str__(self):
- return str(self.name)
-
- def __hash__(self):
- return hash(self.name)
-
- def __eq__(self, other):
- return type(self) == type(other) and self.name == other.name
-
- def emit(self):
- emit(f"{self.name}")
-
-
-# Unary operators
-class UnaryOp(Expr):
- def __init__(self, x: Expr):
- self.x = x
-
- def emit(self):
- pass
-
-class Deref(UnaryOp):
- method = {
- Ptr(Float64x4) : lambda x: emit·load256(None, x),
- }
-
- def emit(self):
- kind = GetType(self.x)
- if kind in self.method:
- self.method[kind](self.x)
- else:
- emit("*")
- self.x.emit()
-
-class Negate(UnaryOp):
- def emit(self):
- emit("~")
- self.x.emit()
-
-class Ref(UnaryOp):
- def emit(self):
- emit("&")
- self.x.emit()
-
-class Inc(UnaryOp):
- def __init__(self, x: Expr, pre=False):
- self.x = x
- self.pre = pre
-
- def emit(self):
- if self.pre:
- emit("++")
- self.x.emit()
- else:
- self.x.emit()
- emit("++")
-
-class Dec(UnaryOp):
- def __init__(self, x: Expr, pre=False):
- self.x = x
- self.pre = pre
-
- def emit(self):
- if self.pre:
- emit("--")
- self.x.emit()
- else:
- self.x.emit()
- emit("--")
-
-# Binary operators
-class BinaryOp(Expr):
- def __init__(self, left: Expr, right: Expr):
- self.l = left
- self.r = right
-
- def emit(self):
- pass
-
-# TODO: check types if they are vectorized and emit correct intrinsic
-class Add(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(f" + ")
- self.r.emit()
-
-class Sub(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(f" - ")
- self.r.emit()
-
-class Mul(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(f" * ")
- self.r.emit()
-
-class Div(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(f" / ")
- self.r.emit()
-
-class And(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(f" & ")
- self.r.emit()
-
-class Xor(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(f" ^ ")
- self.r.emit()
-
-class GT(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(f" > ")
- self.r.emit()
-
-class LT(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(f" < ")
- self.r.emit()
-
-class GE(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(f" >= ")
- self.r.emit()
-
-class LE(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(f" <= ")
- self.r.emit()
-
-class EQ(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(f" == ")
- self.r.emit()
-
-# Loads a scalar
-class SIMDLoad(BinaryOp):
- def emit(self):
- self.l.emit()
- self.r.emit()
-
-# Loads data at address
-class SIMDLoadAt(BinaryOp):
- def emit(self):
- self.l.emit()
- emit(" = _mm256_loadu_pd(")
- self.r.emit()
- emit(")")
-
-# Assignment (stores)
-class Assign(Expr):
- def __init__(self, lhs: Expr, rhs: Expr):
- self.lhs = lhs
- self.rhs = rhs
-
-class Set(Assign):
- method = {
- (Float64x2, Ptr(Float64)) : emit·load128,
- (Float64x4, Float64x2) : emit·broadcast256,
- (Float64x4, Ptr(Float64x4)) : emit·load256,
- (Ptr(Float64x4), Float64x4) : emit·store256,
- (Ptr(Float64x4), Ptr(Float64x4)) : emit·copy256,
- }
- def emit(self):
- lhs = GetType(self.lhs)
- rhs = GetType(self.rhs)
- if (lhs, rhs) in self.method:
- self.method[(lhs, rhs)](self.lhs, self.rhs)
- else:
- self.lhs.emit()
- emit(f" = ")
- self.rhs.emit()
-
-class AddSet(Assign):
- def emit(self):
- self.lhs.emit()
- emit(f" += ")
- self.rhs.emit()
-
-class SubSet(Assign):
- def emit(self):
- self.lhs.emit()
- emit(f" -= ")
- self.rhs.emit()
-
-class MulSet(Assign):
- def emit(self):
- self.lhs.emit()
- emit(f" *= ")
- self.rhs.emit()
-
-class DivSet(Assign):
- def emit(self):
- self.lhs.emit()
- emit(f" /= ")
- self.rhs.emit()
-
-class Comma(Expr):
- def __init__(self, x: Expr, next: Expr):
- self.expr = (x, next)
-
- def emit(self):
- self.expr[0].emit()
- emit(", ")
- self.expr[1].emit()
-
-class Index(Expr):
- def __init__(self, x: Expr, i: Expr):
- self.x = x
- self.i = i
-
- def emit(self):
- self.x.emit()
- emit("[")
- self.i.emit()
- emit("]")
-
-class Paren(Expr):
- def __init__(self, x: Expr):
- self.x = x
-
- def emit(self):
- emit("(")
- self.x.emit()
- emit(")")
-
-class Call(Expr):
- def __init__(self, func: Func, args: List[Param]):
- self.func = func
- self.args = args
-
- def emit(self):
- emit(self.func.name)
- emit("(")
- if numel(self.args) > 0:
- self.args[0].emit()
- for arg in self.args[1:]:
- emit(", ")
- arg.emit()
- emit(")")
-
-def ExprList(*expr: Expr | List[Expr]):
- if numel(expr) > 1:
- x = expr[0]
- return Comma(x, ExprList(*expr[1:]))
- else:
- return expr[0]
-
-# Common assignments are kept globally
-# C statements
-
-class Stmt(Emitter):
- def emit(self):
- emit(";")
-
-class Empty(Stmt):
- pass
-
-class Block(Stmt):
- def __init__(self, *stmts: List[Stmt | Expr]):
- self.stmts = [ s if isinstance(s, Stmt) else StmtExpr(s) for s in stmts ]
-
- def emit(self):
- enter_scope()
- for i, stmt in enumerate(self.stmts):
- stmt.emit()
- if i < numel(self.stmts) - 1:
- emitln()
-
- exits_scope()
-
-class If(Stmt):
- def __init__(self, cond: Expr | List[Expr], then: Stmt | List[Stmt], orelse: Stmt | List[Stmt] | None = None):
- self.cond = cond if isinstance(cond, Expr) else ExprList(*cond)
- self.then = Block(then) if isinstance(then, list) else then
- self.orelse = Block(orelse) if isinstance(orelse, list) else orelse
-
- def emit(self):
- emit("if (")
- self.cond.emit()
- emit(") ")
- self.then.emit()
- if self.orelse is not None:
- self.orelse.emit()
-
-class For(Stmt):
- def __init__(self, init: Expr | List[Expr], cond: Expr | List[Expr], step: Expr | List[Expr], body: Stmt| None = None):
- self.init = init if isinstance(init, Expr) else ExprList(*init)
- self.cond = cond if isinstance(cond, Expr) else ExprList(*cond)
- self.step = step if isinstance(step, Expr) else ExprList(*step)
- self.body = body if body is not None else Empty()
-
- def emit(self):
- emit("for (")
- if self.init is not None:
- self.init.emit()
- emit("; ")
- if self.cond is not None:
- self.cond.emit()
- emit("; ")
- if self.step is not None:
- self.step.emit()
- emit(") ")
-
- if isinstance(self.body, Block):
- self.body.emit()
- else:
- enter_scope()
- self.body.emit()
- exits_scope()
-
- def execute(self, stmt: Stmt | Expr):
- if isinstance(stmt, Expr):
- stmt = StmtExpr(stmt)
-
- if isinstance(self.body, Empty):
- self.body = stmt
- return
- elif not isinstance(self.body, Block):
- self.body = Block(self.body)
- self.body.stmts.append(stmt)
-
-class Return(Stmt):
- def __init__(self, val: Expr):
- self.val = val
-
- def emit(self):
- emitln()
- emit("return ")
- self.val.emit()
- super(Return, self).emit()
-
-class StmtExpr(Stmt):
- def __init__(self, x: Expr):
- self.x = x
-
- def emit(self):
- self.x.emit()
- super(StmtExpr, self).emit()
-
-class Mem(Enum):
- Auto = ""
- Static = "static"
- Register = "register"
- Typedef = "typedef"
- External = "extern"
-
-class Decl(Emitter):
- def __init__(self):
- pass
-
- def emit(self):
- pass
-
-class Func(Decl):
- def __init__(self, name: str, ret: Expr = Void, params: List[Param] = None, vars: List[Var | List[Var]] = None, body: List[Stmt] = None):
- self.name = name
- self.ret = ret
- self.params = params if params is not None else []
- self.vars = vars if vars is not None else []
- self.stmts = body if body is not None else []
-
- def emit(self):
- self.ret.emit()
- emitln()
- emit(self.name)
- emit("(")
- for i, p in enumerate(self.params):
- p.emittype()
- emit(" ")
- p.emitspec()
- if i < numel(self.params) - 1:
- emit(", ")
- emit(")\n")
-
- enter_scope()
-
- for var in self.vars:
- if isinstance(var, list):
- v = var[0]
- v.emittype()
- emit(" ")
- v.emitspec()
- for v in var[1:]:
- emit(", ")
- v.emitspec()
- else:
- var.emittype()
- emit(" ")
- var.emitspec()
-
- emit(";")
- emitln()
-
- if numel(self.vars) > 0:
- emitln()
-
- for stmt in self.stmts[:-1]:
- stmt.emit()
- emitln()
- if numel(self.stmts) > 0:
- self.stmts[-1].emit()
-
- exits_scope()
- emitln(2)
-
- def declare(self, var: Var, *vars: List[Var | List[Var]]) -> Expr | List[Expr]:
- if var.name in [v.name for v in self.vars]:
- return
-
- self.vars.append(var)
- if numel(vars) == 0:
- return Ident(var)
-
- self.vars.extend(vars)
-
- idents = [Ident(var)]
- idents += [Ident(v) for v in vars]
- return idents
-
- def execute(self, stmt: Stmt | Expr, *args: List[Stmt | Expr]):
- def push(n):
- if isinstance(n, Stmt):
- self.stmts.append(n)
- elif isinstance(n, Expr):
- self.stmts.append(StmtExpr(n))
- else:
- raise TypeError("unrecognized type for function")
- push(stmt)
- for arg in args:
- push(arg)
-
- def variables(self, *idents: List[str]) -> List[Expr]:
- vars = {v.name : v for v in self.vars + self.params}
-
- if numel(idents) == 1:
- return Ident(vars[idents[0]])
- return [Ident(vars[ident]) for ident in idents]
-
-class Var(Decl):
- def __init__(self, type: Type, name: str, storage: Mem = Mem.Auto):
- self.name = name
- self.type = type
- self.storage = storage
-
- def emit(self):
- emit(f"{self.name}")
-
- def emittype(self):
- if self.storage != Mem.Auto:
- emit(self.storage.value)
- emit(" ")
- self.type.emit()
-
- def emitspec(self):
- self.type.emitspec(self.name)
-
- @classmethod
- def copy(cls, other: Var):
- return cls(other.type, other.name, other.storage)
-
-class Param(Var):
- def __init__(self, type: Type, name: str):
- return super(Param, self).__init__(type, name, Mem.Auto)
-
- @classmethod
- def copy(cls, other: Param):
- return cls(other.type, other.name)
-
-def Params(*ps: List[Tuple(Type, str)]) -> List[Param]:
- return [Param(p[0], p[1]) for p in ps]
-
-def Vars(*ps: List[Tuple(Type, str)]) -> List[Var]:
- return [Var(p[0], p[1]) for p in ps]
-
-# ------------------------------------------------------------------------
-# AST modification/production functions
-
-# ------------------------------------------
-# basic (non-recursive) commands
-
-def Swap(x: Var, y: Var, tmp: Var) -> Stmt:
- return StmtExpr(ExprList(Set(tmp, x), Set(x, y), Set(y, tmp)))
-
-def IsLoop(s: Stmt) -> bool:
- return type(s) == For
-
-def EvenTo(x: Var, n: int) -> Var:
- return And(x, Negate(I(n-1)))
-
-# ------------------------------------------
-# Expand: takes statement, indexed, and expands it x times
-
-# def Expand(s: Stmt, times: int) -> Block:
-# if not isinstance(s, StmtExpr):
-# raise TypeError(f"{type(x)} not supported by Expand operation")
-
-# return Block(StmtExpr(Step(s.x, i)) for i in range(times))
-
-# ------------------------------------------
-# repeat: takes command on an array and repeats it
-
-@singledispatch
-def Repeat(x: Expr, times: int) -> Expr | List[Expr]:
- raise TypeError(f"{type(x)} not supported by Repeat operation")
-
-@Repeat.register
-def _(x: Inc, times: int):
- return AddSet(x.x, I(times))
-
-@Repeat.register
-def _(x: Dec, times: int):
- return DecSet(x.x, I(times))
-
-@Repeat.register
-def _(x: Comma, times: int):
- return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times))
-
-@singledispatch
-def Repeat(x: Expr, times: int) -> Expr | List[Expr]:
- raise TypeError(f"{type(x)} not supported by Repeat operation")
-
-@Repeat.register
-def _(x: Inc, times: int):
- return AddSet(x.x, I(times))
-
-@Repeat.register
-def _(x: Dec, times: int):
- return DecSet(x.x, I(times))
-
-@Repeat.register
-def _(x: Comma, times: int):
- return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times))
-
-# ------------------------------------------
-# step: indexes an expression by i
-
-@singledispatch
-def Step(x: Expr, i: int) -> Expr:
- raise TypeError(f"{type(x)} not supported by Step operation")
-
-@Step.register
-def _(x: Comma, i: int):
- return Comma(Step(x.expr[0], i), Step(x.expr[1], i))
-
-@Step.register
-def _(x: Assign, i: int):
- return type(x)(Step(x.lhs, i), Step(x.rhs, i))
-
-@Step.register
-def _(x: BinaryOp, i: int):
- return type(x)(Step(x.l, i), Step(x.r, i))
-
-@Step.register
-def _(x: UnaryOp, i: int):
- return type(x)(Step(x.x, i))
-
-@Step.register
-def _(x: Ident, i: int):
- return x
-
-@Step.register
-def _(x: Deref, i: int):
- return Index(x.x, I(i))
-
-@Step.register
-def _(ix: Index, i: int):
- return Index(ix.x, Add(ix.i, I(i)))
-
-# ------------------------------------------
-# bfs search on statements in ast
-
-@singledispatch
-def Visit(s: Stmt, func):
- raise TypeError(f"{type(s)} not supported by Visit operation")
-
-@Visit.register
-def _(s: Empty, func):
- return
-
-@Visit.register
-def _(blk: Block, func):
- for stmt in blk.stmts:
- Visit(stmt, func)
-
-@Visit.register
-def _(jmp: If, func):
- func(jmp.cond)
- Visit(jmp.then, func)
- if jmp.orelse is not None:
- Visit(jmp.orelse, func)
-
-@Visit.register
-def _(loop: For, func):
- func(loop.init)
- func(loop.cond)
- func(loop.step)
- Visit(loop.body, func)
-
-@Visit.register
-def _(ret: Return, func):
- func(ret.val)
-
-@Visit.register
-def _(x: StmtExpr, func):
- func(x.x)
-
-# ------------------------------------------
-# recreates a piece of an AST, allowing an arbitrary transformation of expression nodes
-
-@singledispatch
-def Make(s: Stmt, func):
- raise TypeError(f"{type(s)} not supported by Make operation")
-
-@Make.register
-def _(s: Empty, func):
- return Empty()
-
-@Make.register
-def _(blk: Block, func):
- return Block(*(Make(stmt, func) for stmt in blk.stmts))
-
-@Make.register
-def _(loop: For, func):
- return For(func(loop.init), func(loop.cond), func(loop.step), Make(loop.body, func))
-
-@Make.register
-def _(jmp: If, func):
- if jmp.orelse is not None:
- return If(func(jmp.cond), Make(jmp.then, func), Make(jmp.orelse, func))
- else:
- return If(func(jmp.cond), Make(jmp.then, func))
-
-@Make.register
-def _(ret: Return, func):
- return Return(func(ret.val))
-
-@Make.register
-def _(x: StmtExpr, func):
- return StmtExpr(func(x.x))
-
-# ------------------------------------------
-# GetType function
-
-@singledispatch
-def GetType(x: Expr) -> Type:
- if x is None:
- return
- raise TypeError(f"{type(x)} not supported by GetType operation")
-
-@GetType.register
-def _(x: Empty):
- return Void
-
-@GetType.register
-def _(x: Empty):
- return Void
-
-@GetType.register
-def _(x: S):
- return Ptr(Byte)
-
-@GetType.register
-def _(x: I):
- return Int
-
-@GetType.register
-def _(x: F):
- return Float64
-
-@GetType.register
-def _(x: Ident):
- return x.var.type
-
-@GetType.register
-def _(x: UnaryOp):
- return GetType(x.x)
-
-@GetType.register
-def _(x: Deref):
- return GetType(x.x).to
-
-@GetType.register
-def _(x: Ref):
- return Ptr(GetType(x.x))
-
-@GetType.register
-def _(x: Index):
- base = GetType(x.x)
- if isinstance(base, Ptr):
- return base.to
- elif isinstance(base, Array):
- return base.base
- else:
- pass
- # raise TypeError(f"attempting to index type {base} of node {x.x}")
-
-# TODO: type checking for both
-
-@GetType.register
-def _(x: BinaryOp):
- lhs = GetType(x.l)
- rhs = GetType(x.r)
- return lhs
-
-@GetType.register
-def _(x: Var):
- return x.type
-
-@GetType.register
-def _(x: Assign):
- lhs = GetType(x.lhs)
- rhs = GetType(x.rhs)
- return lhs
-
-@GetType.register
-def _(x: Comma):
- return GetType(x.expr[1])
-
-@GetType.register
-def _(x: Paren):
- return GetType(x.x)
-
-@GetType.register
-def _(x: Call):
- return x.func.ret
-
-# ------------------------------------------
-# Transform function
-# Recurses down AST, and modifies Expr nodes
-# Returns a brand new tree structure
-
-@singledispatch
-def Transform(x: Expr, func: Callable[Expr, Expr]):
- raise TypeError(f"{type(x)} not supported by Transform operation")
-
-@Transform.register
-def _(x: Empty, func):
- return func(Empty())
-
-@Transform.register
-def _(x: Literal, func):
- return func(type(x)(x))
-
-@Transform.register
-def _(x: Ident, func):
- return func(Ident(x.var))
-
-@Transform.register
-def _(x: UnaryOp, func):
- return func(UnaryOp(Transform(x.x, func)))
-
-@Transform.register
-def _(x: Assign, func):
- return func(type(x)(Transform(x.lhs, func), Transform(x.rhs, func)))
-
-@Transform.register
-def _(x: Deref, func):
- return func(Transform(Deref(x.x, func)))
-
-@Transform.register
-def _(x: Index, func):
- return func(Index(Transform(x.x, func), Transform(x.i, func)))
-
-@Transform.register
-def _(x: Comma, func):
- return func(Comma(Transform(x.expr[0], func), Transform(x.expr[1], func)))
-
-@Transform.register
-def _(x: BinaryOp, func):
- return func(type(x)(Transform(x.l, func), Transform(x.r, func)))
-
-# ------------------------------------------
-# Filter function
-# Recurses down AST, and stores Expr nodes that satisfy condition in results
-
-@singledispatch
-def Filter(x: Expr, cond, results: List[Expr]):
- raise TypeError(f"{type(x)} not supported by Filter operation")
-
-@Filter.register
-def _(x: Empty, cond, results: List[Expr]):
- if cond(x):
- results.append(x)
-
-@Filter.register(Ident)
-@Filter.register(Literal)
-def _(x, cond, results: List[Expr]):
- if cond(x):
- results.append(x)
-
-@Filter.register
-def _(x: UnaryOp, cond, results: List[Expr]):
- if cond(x):
- results.append(x)
- Filter(op.x, cond, results)
-
-@Filter.register
-def _(s: Assign, cond, results: List[Expr]):
- if cond(s):
- results.append(s)
- Filter(s.lhs, cond, results)
- Filter(s.rhs, cond, results)
-
-@Filter.register
-def _(v: Deref, cond, results: List[Expr]):
- if cond(v):
- results.append(v)
- Filter(v.x, cond, results)
-
-@Filter.register
-def _(i: Index, cond, results: List[Expr]):
- if cond(i):
- results.append(i)
- Filter(i.x, cond, results)
- Filter(i.i, cond, results)
-
-@Filter.register
-def _(comma: Comma, cond, results: List[Expr]):
- if cond(comma):
- results.append(comma)
- Filter(comma.expr[0], cond, results)
- Filter(comma.expr[1], cond, results)
-
-@Filter.register
-def _(op: BinaryOp, cond, results: List[Expr]):
- if cond(op):
- results.append(op)
- Filter(op.l, cond, results)
- Filter(op.r, cond, results)
-
-# ------------------------------------------
-# Eval function
-# Recurses down AST, and evaluates Expr nodes
-# Throws an error if tree can't be evaluated at compile time
-
-@singledispatch
-def Eval(x: Expr) -> object:
- raise TypeError(f"{type(x)} not supported by Eval operation")
-
-@Eval.register
-def _(x: Empty):
- pass
-
-@Eval.register(Ident)
-@Eval.register(S)
-def _(x):
- return sympy.symbols(f"{x}")
-
-@Eval.register(float)
-@Eval.register(int)
-@Eval.register(I)
-@Eval.register(F)
-def _(x):
- return x
-
-@Eval.register
-def _(x: Inc):
- return Eval(Add(Eval(op.x, cond, results), I(1)))
-
-@Eval.register
-def _(x: Dec):
- return Eval(Sub(Eval(op.x, cond, results), I(1)))
-
-# TODO: This won't work in general (if we have things like sizeof(x) + 1 - sizeof(x))
-@Eval.register
-def _(op: Add):
- l = Eval(op.l)
- r = Eval(op.r)
- return sympy.simplify(l + r)
-
-@Eval.register
-def _(op: Sub):
- l = Eval(op.l)
- r = Eval(op.r)
- return sympy.simplify(l - r)
-
-@Eval.register
-def _(op: Mul):
- l = Eval(op.l)
- r = Eval(op.r)
- return sympy.simplify(l * r)
-
-@Eval.register
-def _(op: Div):
- l = Eval(op.l)
- r = Eval(op.r)
- return sympy.simplify(l / r)
-
-@Eval.register
-def _(s: Assign):
- return Eval(s.rhs)
-
-@Eval.register
-def _(comma: Comma):
- return Eval(comma.expr[1])
-
-# ------------------------------------------
-# Leaf traversal
-
-def VarsUsed(stmt: Stmt) -> List[Var]:
- vars = []
- Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), vars))
- vars = set([v.var for v in vars])
-
- return vars
-
-def AddrAccessed(stmt: Stmt):
- scalars = [] # variables that are accessed as scalars (single sites)
- vectors = {} # variables that are accessed as vectors (indexed/dereferenced)
- nodes = []
- Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), scalars))
- Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Index), nodes))
-
- for node in nodes:
- vars = []
- Filter(node.x, lambda x: isinstance(x, Ident), vars)
- if numel(vars) != 1:
- raise ValueError("multiple variables used in index expression not supported")
- vectors[vars[0]] = node.i
- if vars[0] in scalars:
- scalars.remove(vars[0])
-
- return set(scalars), vectors
-
-# ------------------------------------------
-# Large scale functions
-
-def Unroll(name: str, loop: For, times: int, ret: Ident = None, accumulator = None) -> (Func, List[Stmt]):
- # TODO: More sophisticated computation for length of loop
- if not isinstance(loop.cond, LE) and not isinstance(loop.cond, LT):
- raise TypeError(f"{type(loop.cond)} not supported in loop unrolling")
-
- # pull off needed features of the loop
- i = loop.init.lhs.var
- vars = VarsUsed(loop.body)
-
- def asvector(v):
- if v != i:
- return Var(Array(v.type, times), v.name, v.storage)
- else:
- return Var.copy(v)
-
- n = loop.cond.r.var
- param = [Param(Ptr(i.type), f"{i.name}p")] + [Param.copy(v) for v in vars if (type(v) == Param and v != n)]
- stack = {v.name: asvector(v) for v in vars if type(v) == Var}
-
- # TODO: More sophisticated type checking
- if (type(n) != Param):
- raise TypeError(f"{type(n)} not implemented yet")
-
- if ret is None:
- kernel = Func(f"{name}_kernel{times}", Void, param, list(stack.values()))
- else:
- kernel = Func(f"{name}_kernel{times}", ret.var.type, param, list(stack.values()))
-
- len = kernel.declare(n)
-
- body = loop.body
- itor, itorp = kernel.variables("i", "ip")
-
- def mkarray(x: Expr, times: int):
- if isinstance(x, Ident):
- if type(x.var) == Var:
- x.var = stack[x.name]
- return x
-
- def step(x: Expr, times: int):
- if x == itor:
- return Add(x, I(times))
- if isinstance(x, Ident):
- if type(x.var) == Var:
- return Index(x, I(times))
- if not isinstance(x, Expr):
- raise ValueError(f"panic, hit type {type(x)}")
-
- return x
-
- expandedloop = Make(loop.body, lambda node: Transform(node, lambda x: mkarray(x, times)))
- kernel.execute(
- Set(len, EvenTo(Deref(itorp), times)),
- For(Set(itor, I(0)), LT(itor, len), Repeat(loop.step, times),
- body = Block(*
- (Make(expandedloop, lambda node:
- Transform(node, lambda x: step(x, i))) for i in range(times)
- )
- )
- ),
- Set(Deref(itorp), itor)
- )
-
- if ret is not None:
- k = kernel.variables(ret.name)
- if k.var.type != ret.var.type:
- if accumulator is None:
- raise ValueError("If loop returns a value, an accumulator must be given")
- else:
- accumulator = For(Set(itor, I(1)), LT(itor, I(times)), Inc(itor),
- body = accumulator(I(0), itor, k)
- )
- kernel.execute(accumulator)
-
- kernel.execute(Return(Index(ret, I(0))))
-
- loop.init = None
- if ret is None:
- return kernel, [Set(itor, len), Call(kernel, [Ref(itor)] + param[1:]), loop]
- else:
- return kernel, [Set(itor, len), Set(ret, Call(kernel, [Ref(itor)] + param[1:])), loop]
-
-# Replaces all vectorizable loops inside Func with vectorized variants
-# Returns the new vectorized function (tagged by the SIMD chosen)
-
-class SymTab(object):
- def __init__(self):
- self.stack = {}
- self.addrs = {}
-
- def have(self, name):
- return name in self.stack or name in self.addrs
-
-def Vectorize(func: Func, isa: SIMD) -> Func:
- if isa != SIMD.AVX2:
- raise ValueError(f"ISA '{isa}' not currently implemented")
-
- vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params)
- for stmt in func.stmts:
- if IsLoop(stmt):
- loop = For(stmt.init, stmt.cond, stmt.step)
- iterator = set([stmt.init.lhs])
-
- body = stmt.body
- if type(body) != Block:
- # TODO: Think through this carefully...
- # This is coded for the accumulation step at the end of a kernel
- loop.cond.r = I(Eval(Div(loop.cond.r, 4)))
- loop.body = body
- else:
- instr = body.stmts
- # TODO: Remove hardcoded 4 -> should be function of types!
- # As of now, we have harcoded AVX2 <-> float64 relationship
- if numel(instr) % 4 != 0:
- raise ValueError("loop can not be vectorized, instructions can not be globbed equally")
-
- # TODO: Allow for non-sequential accesses?
- for i in range(0, numel(instr), 4):
- scalars = []
- vectors = []
- for j in range(4):
- s, v = AddrAccessed(instr[i+j])
- s -= iterator # TODO: This is hacky
- scalars.append(frozenset(s))
- vectors.append(v)
-
- # Test if code in uniform to allow for vectorization
- if numel(set(scalars)) != 1:
- raise ValueError("non uniform scalar accesses in consecutive line. can not vectorize")
-
- for j in range(1, 4):
- for v, idx in vectors[j].items():
- if (delta := (Eval(Sub(idx, vectors[j-1][v])))) != 1:
- print(f"{delta}")
- raise ValueError("non uniform vector accesses in consecutive line. can not vectorize")
-
- # If we made it to here, we have passed all checks. vectorize!
- vecs = vectors[0]
- if i == 0:
- syms = SymTab()
- for s in scalars[0]:
- intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128"))
- syms.stack[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
- vfunc.execute(Set(intermediate, Ref(s)))
- vfunc.execute(Set(syms.stack[s.name], intermediate))
-
- for v in vecs.keys():
- # All params are treated AS addresses to load into vectorized registers
- if type(v.var) == Param:
- syms.addrs[v.name] = Var(Ptr(Float64x4), f"{v.name}256")
- # All stack declared parameters MUST be moved to vectorized registers
- else:
- assert IsArrayType(v.var.type), f"must be an array type, instead got {v.var.type}"
- nreg = v.var.type.len // 4
- if nreg > 1:
- syms.stack[v.name] = vfunc.declare(Var(Array(Float64x4, nreg), f"{v.name}256"))
- else:
- syms.stack[v.name] = vfunc.declare(Var(Float64x4, f"{v.name}256"))
-
-
- # IMPORTANT: We do a post-order traversal.
- # We transforms leaves (identifiers) first and then move back up the root
- def translate(x):
- if isinstance(x, Ident):
- if x.name in syms.stack:
- return syms.stack[x.name]
- elif x.name in syms.addrs:
- x.var = syms.addrs[x.name]
- if isinstance(x, Index):
- if type(x.x) == Ident:
- if x.x.name in syms.addrs:
- return Add(x.x, x.i)
- elif f"{x.x.name[:-3]}" in syms.stack: #NOTE: This is hacky. Think of something better
- return Index(x.x, I(Eval(Div(x.i, 4))))
- if isinstance(x, BinaryOp):
- l, r = GetType(x.l), GetType(x.r)
- if IsVectorType(l) or IsVectorType(r):
- if type(l) == Ptr:
- x.l = Deref(x.l)
- if type(r) == Ptr:
- x.r = Deref(x.r)
-
- return x
-
- loop.execute(Make(instr[i], lambda node: Transform(node, translate)))
-
- vfunc.execute(loop)
- else:
- vfunc.execute(stmt)
-
- return vfunc
-
-
-def Strided(func: Func) -> Func:
- pass
-
-# ------------------------------------------------------------------------
-# Point of testing
-
-def copy():
- F = Func("blas·copy", Void,
- Params(
- (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y")
- ),
- Vars(
- (Int, "i"), (Float64, "tmp")
- )
- )
- # could also declare like
- # F.declare( ... )
-
- len, x, y, i, tmp = F.variables("len", "x", "y", "i", "tmp")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- Swap(Index(x, i), Index(y, i), tmp)
- )
-
- kernel, calls = Unroll("copy", loop, 8)
- kernel.emit()
-
- avx256kernel = Vectorize(kernel, SIMD.AVX2)
- avx256kernel.emit()
-
- F.execute(*calls)
- F.emit()
-
-def axpby():
- F = Func("blas·axpby", Void,
- Params(
- (Int, "len"), (Float64, "a"), (Ptr(Float64), "x"), (Float64, "b"), (Ptr(Float64), "y")
- ),
- Vars(
- (Int, "i"), #(Float64, "tmp")
- )
- )
- # could also declare like
- # F.declare( ... )
-
- # TODO: Increase ergonomics here...
- len, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- Set(Index(y, i),
- Mul(b,
- Add(Index(y, i),
- Mul(a, Index(x, i)),
- )
- )
- )
- )
-
- kernel, calls = Unroll("axpby", loop, 8)
- kernel.emit()
-
- avx256kernel = Vectorize(kernel, SIMD.AVX2)
- avx256kernel.emit()
-
- F.execute(*calls)
- F.emit()
-
-def argmax():
- F = Func("blas·argmax", Int,
- Params(
- (Int, "len"), (Ptr(Float64), "x"),
- ),
- Vars(
- (Int, "i"), (Int, "ix"), (Float64, "max")
- )
- )
- len, x, i, ix, max = F.variables("len", "x", "i", "ix", "max")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- If(GT(Index(x, i), max),
- Block(
- Set(ix, i),
- Set(max, Index(x, ix)),
- )
- )
- )
- kernel, calls = Unroll("argmax", loop, 8, ix,
- lambda ires, icur, node:
- If(GT(Index(x, Index(node, icur)), Index(x, Index(node, ires))),
- Block(Set(Index(node, ires), Index(node, icur)))
- )
- )
- kernel.emit()
-
- # avx256kernel = Vectorize(kernel, SIMD.AVX2)
- # avx256kernel.emit()
-
- F.execute(*calls, Return(ix))
-
- F.execute(loop)
- F.emit()
-
-def dot():
- F = Func("blas·dot", Float64,
- Params(
- (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y")
- ),
- Vars(
- (Int, "i"), (Float64, "sum"),
- )
- )
- len, x, i, y, sum = F.variables("len", "x", "i", "y", "sum")
-
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- loop.execute(
- AddSet(sum, Mul(Index(x, i), Index(y, i)))
- )
-
- kernel, calls = Unroll("dot", loop, 16, sum,
- lambda ires, icur, node:
- StmtExpr(AddSet(Index(node, ires), Index(node, icur)))
- )
- kernel.emit()
-
- avx256kernel = Vectorize(kernel, SIMD.AVX2)
- avx256kernel.emit()
-
- F.execute(*calls, Return(sum))
-
- F.emit()
-
-def gemv():
- F = Func("blas·gemv", Void,
- Params(
- (Int, "nrow"), (Int, "ncol"), (Float64, "a"), (Ptr(Float64), "m"), (Int, "incm"),
- (Ptr(Float64), "x"), (Float64, "b"), (Ptr(Float64), "y")
- ),
- Vars(
- (Int, "r"), (Int, "c"), (Ptr(Float64), "row"), (Float64, "res")
- )
- )
- r, c, row, res = F.variables("r", "c", "row", "res")
- nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")
- loop = For(Set(r, I(0)), LT(r, nrow), Inc(r),
- Block(
- Set(row, Add(m, incm)),
- Set(res, I(0)),
- For(Set(c, I(0)), LT(c, ncol), Inc(c),
- AddSet(res, Mul(Index(row, c), Index(x, c)))
- ),
- Set(Index(y, r), Add(Mul(a, res), Mul(b, Index(y, r))))
- )
- )
-
- F.execute(loop)
- F.emit()
-
-
-if __name__ == "__main__":
- emitheader()
-
- gemv()
- # dot()
- # argmax()
- # copy()
- # axpby()
-
- print(buffer)