aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-11 17:52:07 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-11 17:52:07 -0700
commit2f4de9abb28ecb5af7390e7011a40f493c4b4dae (patch)
tree6a18b2dbe27052e9be41ed7c2bf0cf00ace01c78
parenta023f63a6ff79dc7e5aa18771062d333e62a7411 (diff)
feat: begun adding lookup table of functions
-rw-r--r--lib/c.py481
1 files changed, 416 insertions, 65 deletions
diff --git a/lib/c.py b/lib/c.py
index 25c97b2..070ade6 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -5,10 +5,11 @@ import os
import sys
from enum import Enum
-from typing import Set, List, Tuple, Dict
+from typing import List, Tuple, Dict
from functools import singledispatch
numel = len
+
# ------------------------------------------------------------------------
# String buffer
@@ -53,33 +54,6 @@ class Emitter(object):
# ------------------------------------------
# Representation of a C type
-class TypeKind(Enum):
- Void = "void"
- Error = "error"
-
- Int = "int"
- Int32 = "int32"
- Int64 = "int64"
-
- # vectorized variants
- Int32x4 = "__mm128i"
- Int32x8 = "__mm256i"
- Int64x2 = "__mm128i"
- Int64x4 = "__mm256i"
-
- Float32 = "float"
- Float64 = "double"
-
- # vectorized variants
- Float32x4 = "__m128"
- Float32x8 = "__mm256"
- Float64x2 = "__mm128d"
- Float64x4 = "__mm256d"
-
- Pointer = "pointer"
- Struct = "struct"
- Enum = "enum"
- Union = "union"
class Type(Emitter):
def emit(self):
@@ -92,16 +66,71 @@ class Base(Type):
def __init__(self, name: str):
self.name = name
+ def __hash__(self):
+ return hash(self.name)
+
+ def __eq__(self, other):
+ return type(self) == type(other) and self.name == other.name
+
+ def __str__(self):
+ return self.name
+
def emit(self):
emit(self.name)
def emitspec(self, ident):
emit(f"{ident}")
+# TODO: Operation lookup tables...
+
+class Ptr(Type):
+ def __init__(self, to: Type):
+ self.to = to
+
+ def __str__(self):
+ return f"*{self.to}"
+
+ def __eq__(self, other):
+ return type(self) == type(other) and self.to == other.to
+
+ def __hash__(self):
+ return hash((hash(self.to), "*"))
+
+ def emit(self):
+ self.to.emit()
+
+ def emitspec(self, ident):
+ emit("*")
+ self.to.emitspec(ident)
+
+class Array(Type):
+ def __init__(self, base: Type, len: int):
+ self.base = base
+ self.len = len
+
+ def __str__(self):
+ return f"[{self.len}]{self.base}"
+
+ def __eq__(self, other):
+ return type(self) == type(other) and self.base == other.base
+
+ def __hash__(self):
+ return hash((hash(self.base), self.len))
+
+ def emit(self):
+ self.base.emit()
+
+ def emitspec(self, ident):
+ self.base.emitspec(ident)
+ emit(f"[{self.len}]")
+
# Machine primitive types
Void = Base("void")
Error = Base("error")
+Byte = Base("byte")
+String = Ptr(Byte)
+
Int = Base("int")
Int8 = Base("int8")
Int16 = Base("int16")
@@ -119,6 +148,18 @@ Float32x8 = Base("__mm256")
Float64x2 = Base("__m128d")
Float64x4 = Base("__mm256d")
+def VectorType(kind):
+ if kind is Float32x4 or \
+ kind is Float32x8 or \
+ kind is Float64x2 or \
+ kind is Float64x4 or \
+ kind is Int32x4 or \
+ kind is Int32x8 or \
+ kind is Int64x2 or \
+ kind is Int64x4:
+ return True
+ return False
+
# TODO: Make this ARCH dependent
BitDepth= {
Void: 8,
@@ -155,31 +196,6 @@ SIMDSupport = {
SIMD.AVX5: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
}
-# TODO: Operation lookup tables...
-
-class Ptr(Type):
- def __init__(self, to: Type):
- self.to = to
-
- def emit(self):
- self.to.emit()
-
- def emitspec(self, ident):
- emit("*")
- self.to.emitspec(ident)
-
-class Array(Type):
- def __init__(self, base: Type, len: int):
- self.base = base
- self.len = len
-
- def emit(self):
- self.base.emit()
-
- def emitspec(self, ident):
- self.base.emitspec(ident)
- emit(f"[{self.len}]")
-
# TODO: Typedefs...
# ------------------------------------------
@@ -309,7 +325,6 @@ class Xor(BinaryOp):
emit(f" ^ ")
self.r.emit()
-
class GT(BinaryOp):
def emit(self):
self.l.emit()
@@ -340,17 +355,64 @@ class EQ(BinaryOp):
emit(f" == ")
self.r.emit()
+# Loads a scalar
+class SIMDLoad(BinaryOp):
+ def emit(self):
+ self.l.emit()
+ self.r.emit()
+
+# Loads data at address
+class SIMDLoadAt(BinaryOp):
+ def emit(self):
+ self.l.emit()
+ emit(" = _mm256_loadu_pd(")
+ self.r.emit()
+ emit(")")
+
# Assignment (stores)
class Assign(Expr):
def __init__(self, lhs: Expr, rhs: Expr):
self.lhs = lhs
self.rhs = rhs
+def emit·load256(l, r):
+ emit("_mm256_loadu_pd(")
+ l.emit()
+ emit(", ")
+ r.emit()
+ emit(")")
+
+def emit·store256(l, r):
+ emit("_mm256_storeu_pd(")
+ l.emit()
+ emit(", ")
+ r.emit()
+ emit(")")
+
+def emit·copy256(l, r):
+ emit("_mm256_storeu_pd(")
+ l.emit()
+ emit(", ")
+ emit("_mm256_loadu_pd(")
+ r.emit()
+ emit("))")
+
+SetTypeDispatch = {
+ (Float64x4, Ptr(Float64x4)) : emit·load256,
+ (Ptr(Float64x4), Float64x4) : emit·store256,
+ (Ptr(Float64x4), Ptr(Float64x4)) : emit·copy256,
+}
+
class Set(Assign):
def emit(self):
- self.lhs.emit()
- emit(f" = ")
- self.rhs.emit()
+ lhs = GetType(self.lhs)
+ rhs = GetType(self.rhs)
+ if (lhs, rhs) in SetTypeDispatch:
+ SetTypeDispatch[(lhs, rhs)](self.lhs, self.rhs)
+ else:
+ self.lhs.emit()
+ emit(f" = ")
+ self.rhs.emit()
class AddSet(Assign):
def emit(self):
@@ -479,6 +541,17 @@ class For(Stmt):
self.body.emit()
exits_scope()
+ def execute(self, stmt: Stmt | Expr):
+ if isinstance(stmt, Expr):
+ stmt = StmtExpr(stmt)
+
+ if isinstance(self.body, Empty):
+ self.body = stmt
+ return
+ elif not isinstance(self.body, Block):
+ self.body = Block([self.body])
+ self.body.stmts.append(stmt)
+
class Return(Stmt):
def __init__(self, val: Expr):
self.val = val
@@ -563,6 +636,9 @@ class Func(Decl):
exits_scope()
def declare(self, var: Var, *vars: List[Var | List[Var]]) -> Expr | List[Expr]:
+ if var.name in [v.name for v in self.vars]:
+ return
+
self.vars.append(var)
if numel(vars) == 0:
return Ident(var)
@@ -727,7 +803,163 @@ def _(x: StmtExpr, func):
func(x.x)
# ------------------------------------------
-# expression functions
+# recreates a piece of an AST, allowing an arbitrary transformation of expression nodes
+
+@singledispatch
+def Make(s: Stmt, func):
+ raise TypeError(f"{type(s)} not supported by Make operation")
+
+@Make.register
+def _(s: Empty, func):
+ return Empty()
+
+@Make.register
+def _(blk: Block, func):
+ return Block([func(stmt) for stmt in blk.Stmts])
+
+@Make.register
+def _(loop: For, func):
+ return For(func(loop.init), func(loop.cond), func(loop.step), func(loop.body))
+
+@Make.register
+def _(ret: Return, func):
+ return Return(func(ret.val))
+
+@Make.register
+def _(x: StmtExpr, func):
+ return StmtExpr(func(x.x))
+
+# ------------------------------------------
+# GetType function
+
+@singledispatch
+def GetType(x: Expr) -> Type:
+ raise TypeError(f"{type(x)} not supported by GetType operation")
+
+@GetType.register
+def _(x: Empty):
+ return Void
+
+@GetType.register
+def _(x: Empty):
+ return Void
+
+@GetType.register
+def _(x: S):
+ return Ptr(Byte)
+
+@GetType.register
+def _(x: I):
+ return Int
+
+@GetType.register
+def _(x: F):
+ return Float64
+
+@GetType.register
+def _(x: Ident):
+ return x.var.type
+
+@GetType.register
+def _(x: UnaryOp):
+ return GetType(x.x)
+
+@GetType.register
+def _(x: Deref):
+ return GetType(x.x).to
+
+@GetType.register
+def _(x: Ref):
+ return Ptr(GetType(x.x))
+
+@GetType.register
+def _(x: Index):
+ base = GetType(x.x)
+ if isinstance(base, Ptr):
+ return base.to
+ elif isinstance(base, Array):
+ return base.base
+ else:
+ raise TypeError(f"attempting to index type {base}")
+
+# TODO: type checking for both
+
+@GetType.register
+def _(x: BinaryOp):
+ lhs = GetType(x.l)
+ rhs = GetType(x.r)
+ return lhs
+
+@GetType.register
+def _(x: Var):
+ return x.type
+
+@GetType.register
+def _(x: Assign):
+ lhs = GetType(x.lhs)
+ rhs = GetType(x.rhs)
+ return lhs
+
+@GetType.register
+def _(x: Comma):
+ return GetType(x.expr[1])
+
+@GetType.register
+def _(x: Paren):
+ return GetType(x.x)
+
+@GetType.register
+def _(x: Call):
+ return x.func.ret
+
+# ------------------------------------------
+# Transform function
+# Recurses down AST, and modifies Expr nodes
+# Returns a brand new tree structure
+
+@singledispatch
+def Transform(x: Expr, func):
+ raise TypeError(f"{type(x)} not supported by Transform operation")
+
+@Transform.register
+def _(x: Empty, func):
+ return func(Empty())
+
+@Transform.register
+def _(x: Literal, func):
+ return func(type(x)(x))
+
+@Transform.register
+def _(x: Ident, func):
+ return func(Ident(x.var))
+
+@Transform.register
+def _(x: UnaryOp, func):
+ return func(UnaryOp(Transform(x.x, func)))
+
+@Transform.register
+def _(x: Assign, func):
+ return func(type(x)(Transform(x.lhs, func), Transform(x.rhs, func)))
+
+@Transform.register
+def _(x: Deref, func):
+ return func(Transform(Deref(x.x, func)))
+
+@Transform.register
+def _(x: Index, func):
+ return func(Index(Transform(x.x, func), Transform(x.i, func)))
+
+@Transform.register
+def _(x: Comma, func):
+ return func(Comma(Transform(x.expr[0], func), Transform(x.expr[1], func)))
+
+@Transform.register
+def _(x: BinaryOp, func):
+ return func(BinaryOp(Transform(x.l, func), Transform(x.r, func)))
+
+# ------------------------------------------
+# Filter function
+# Recurses down AST, and stores Expr nodes that satisfy condition in results
@singledispatch
def Filter(x: Expr, cond, results: List[Expr]):
@@ -784,13 +1016,82 @@ def _(op: BinaryOp, cond, results: List[Expr]):
Filter(op.l, cond, results)
Filter(op.r, cond, results)
-def VarsUsed(stmt: Stmt) -> Set[Var]:
+# ------------------------------------------
+# Eval function
+# Recurses down AST, and evaluates Expr nodes
+# Throws an error if tree can't be evaluated at compile time
+
+@singledispatch
+def Eval(x: Expr) -> object:
+ raise TypeError(f"{type(s)} not supported by Eval operation")
+
+@Eval.register
+def _(x: Empty):
+ pass
+
+@Eval.register(Ident)
+@Eval.register(Literal)
+def _(x):
+ return x
+
+@Eval.register
+def _(x: Inc):
+ return Eval(Add(Eval(op.x, cond, results), I(1)))
+
+@Eval.register
+def _(x: Dec):
+ return Eval(Sub(Eval(op.x, cond, results), I(1)))
+
+# TODO: This won't work in general (if we have things like sizeof(x) + 1 - sizeof(x))
+@Eval.register
+def _(op: Add):
+ l = Eval(op.l)
+ r = Eval(op.r)
+ return l + r
+
+@Eval.register
+def _(op: Sub):
+ l = Eval(op.l)
+ r = Eval(op.r)
+ # TODO: This won't work in general (if we have things like sizeof(x) + 1 - sizeof(x))
+ return l - r
+
+@Eval.register
+def _(s: Assign):
+ return Eval(s.rhs)
+
+@Eval.register
+def _(comma: Comma):
+ return Eval(comma.expr[1])
+
+# ------------------------------------------
+# Leaf traversal
+
+def VarsUsed(stmt: Stmt) -> List[Var]:
vars = []
- Visit(loop.body, lambda node: Filter(node, lambda x: isinstance(x, Ident), vars))
+ Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), vars))
vars = set([v.var for v in vars])
return vars
+def AddrAccessed(stmt: Stmt):
+ scalars = [] # variables that are accessed as scalars (single sites)
+ vectors = {} # variables that are accessed as vectors (indexed/dereferenced)
+ nodes = []
+ Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), scalars))
+ Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Index), nodes))
+
+ for node in nodes:
+ vars = []
+ Filter(node.x, lambda x: isinstance(x, Ident), vars)
+ if numel(vars) != 1:
+ raise ValueError("multiple variables used in index expression not supported")
+ vectors[vars[0]] = node.i
+ if vars[0] in scalars:
+ scalars.remove(vars[0])
+
+ return frozenset(scalars), vectors
+
# ------------------------------------------
# Large scale functions
@@ -841,13 +1142,63 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
print("could not vectorize loop, skipping")
loop.body = body
else:
- instrs = body.stmts
- # TODO: Remove hardcoded 4
- if numel(instrs) % 4 != 0:
+ instr = body.stmts
+ # TODO: Remove hardcoded 4 -> should be function of types!
+ # As of now, we have harcoded AVX2 <-> float64 relationship
+ if numel(instr) % 4 != 0:
raise ValueError("loop can not be vectorized, instructions can not be globbed equally")
- # TODO: Allow for non-sequential accesses?
- # for i in range(numel(instrs)/4):
+ # TODO: Allow for non-sequential accesses?
+ s_symtab, v_symtab = {}, {}
+ for i in range(0, numel(instr), 4):
+ scalars = []
+ vectors = []
+ for j in range(4):
+ s, v = AddrAccessed(instr[i+j])
+ scalars.append(s)
+ vectors.append(v)
+
+ # Test if code in uniform to allow for vectorization
+ if numel(set(scalars)) != 1:
+ raise ValueError("non uniform scalar accesses in consecutive line. can not vectorize")
+
+ for j in range(1, 4):
+ for v, idx in vectors[j].items():
+ if Eval(Sub(idx, vectors[j-1][v])) != 1:
+ raise ValueError("non uniform vector accesses in consecutive line. can not vectorize")
+
+ # If we made it to here, we have passed all checks. vectorize!
+ # Create symbol table
+ if i == 0:
+ for s in scalars[0]:
+ s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
+
+ for v in vectors[0].keys():
+ v_symtab[v.name] = Var(Ptr(Float64x4), f"{v.name}256")
+
+ # Necessary loads into registers
+ # NOTE: Generalization of a Deref
+ # for v in vectors[0].keys():
+ # load = SIMDLoadAt(symtab[v.name], Add(v, I(i)))
+ # loop.execute(load)
+
+ # IMPORTANT: We do a post-order traversal.
+ # The identifier will be substituted first!
+ def translate(x):
+ if type(x) == Ident:
+ if x.name in s_symtab:
+ return s_symtab[x.name]
+ elif x.name in v_symtab:
+ x.var = v_symtab[x.name]
+ if type(x) == Index:
+ if type(x.x) == Ident and x.x.name in v_symtab:
+ return Add(x.x, x.i)
+
+ # if type(x) == Index and x.x in set(symtab.values()):
+ # return x.x
+ return x
+
+ loop.execute(Make(instr[i], lambda node: Transform(node, translate)))
vfunc.execute(loop)
else: