From 2f4de9abb28ecb5af7390e7011a40f493c4b4dae Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Mon, 11 May 2020 17:52:07 -0700 Subject: feat: begun adding lookup table of functions --- lib/c.py | 481 ++++++++++++++++++++++++++++++++++++++++++++++++++++++--------- 1 file changed, 416 insertions(+), 65 deletions(-) diff --git a/lib/c.py b/lib/c.py index 25c97b2..070ade6 100644 --- a/lib/c.py +++ b/lib/c.py @@ -5,10 +5,11 @@ import os import sys from enum import Enum -from typing import Set, List, Tuple, Dict +from typing import List, Tuple, Dict from functools import singledispatch numel = len + # ------------------------------------------------------------------------ # String buffer @@ -53,33 +54,6 @@ class Emitter(object): # ------------------------------------------ # Representation of a C type -class TypeKind(Enum): - Void = "void" - Error = "error" - - Int = "int" - Int32 = "int32" - Int64 = "int64" - - # vectorized variants - Int32x4 = "__mm128i" - Int32x8 = "__mm256i" - Int64x2 = "__mm128i" - Int64x4 = "__mm256i" - - Float32 = "float" - Float64 = "double" - - # vectorized variants - Float32x4 = "__m128" - Float32x8 = "__mm256" - Float64x2 = "__mm128d" - Float64x4 = "__mm256d" - - Pointer = "pointer" - Struct = "struct" - Enum = "enum" - Union = "union" class Type(Emitter): def emit(self): @@ -92,16 +66,71 @@ class Base(Type): def __init__(self, name: str): self.name = name + def __hash__(self): + return hash(self.name) + + def __eq__(self, other): + return type(self) == type(other) and self.name == other.name + + def __str__(self): + return self.name + def emit(self): emit(self.name) def emitspec(self, ident): emit(f"{ident}") +# TODO: Operation lookup tables... 
+ +class Ptr(Type): + def __init__(self, to: Type): + self.to = to + + def __str__(self): + return f"*{self.to}" + + def __eq__(self, other): + return type(self) == type(other) and self.to == other.to + + def __hash__(self): + return hash((hash(self.to), "*")) + + def emit(self): + self.to.emit() + + def emitspec(self, ident): + emit("*") + self.to.emitspec(ident) + +class Array(Type): + def __init__(self, base: Type, len: int): + self.base = base + self.len = len + + def __str__(self): + return f"[{self.len}]{self.base}" + + def __eq__(self, other): + return type(self) == type(other) and self.base == other.base and self.len == other.len + + def __hash__(self): + return hash((hash(self.base), self.len)) + + def emit(self): + self.base.emit() + + def emitspec(self, ident): + self.base.emitspec(ident) + emit(f"[{self.len}]") + # Machine primitive types Void = Base("void") Error = Base("error") +Byte = Base("byte") +String = Ptr(Byte) + Int = Base("int") Int8 = Base("int8") Int16 = Base("int16") @@ -119,6 +148,18 @@ Float32x8 = Base("__mm256") Float64x2 = Base("__m128d") Float64x4 = Base("__mm256d") +def VectorType(kind): + if kind is Float32x4 or \ + kind is Float32x8 or \ + kind is Float64x2 or \ + kind is Float64x4 or \ + kind is Int32x4 or \ + kind is Int32x8 or \ + kind is Int64x2 or \ + kind is Int64x4: + return True + return False + # TODO: Make this ARCH dependent BitDepth= { Void: 8, @@ -155,31 +196,6 @@ SIMDSupport = { SIMD.AVX5: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]), } -# TODO: Operation lookup tables... - -class Ptr(Type): - def __init__(self, to: Type): - self.to = to - - def emit(self): - self.to.emit() - - def emitspec(self, ident): - emit("*") - self.to.emitspec(ident) - -class Array(Type): - def __init__(self, base: Type, len: int): - self.base = base - self.len = len - - def emit(self): - self.base.emit() - - def emitspec(self, ident): - self.base.emitspec(ident) - emit(f"[{self.len}]") - # TODO: Typedefs...
# ------------------------------------------ @@ -309,7 +325,6 @@ class Xor(BinaryOp): emit(f" ^ ") self.r.emit() - class GT(BinaryOp): def emit(self): self.l.emit() @@ -340,17 +355,64 @@ class EQ(BinaryOp): emit(f" == ") self.r.emit() +# Loads a scalar +class SIMDLoad(BinaryOp): + def emit(self): + self.l.emit() + self.r.emit() + +# Loads data at address +class SIMDLoadAt(BinaryOp): + def emit(self): + self.l.emit() + emit(" = _mm256_loadu_pd(") + self.r.emit() + emit(")") + # Assignment (stores) class Assign(Expr): def __init__(self, lhs: Expr, rhs: Expr): self.lhs = lhs self.rhs = rhs +def emit·load256(l, r): + l.emit() + emit(" = ") + emit("_mm256_loadu_pd(") + r.emit() + emit(")") + +def emit·store256(l, r): + emit("_mm256_storeu_pd(") + l.emit() + emit(", ") + r.emit() + emit(")") + +def emit·copy256(l, r): + emit("_mm256_storeu_pd(") + l.emit() + emit(", ") + emit("_mm256_loadu_pd(") + r.emit() + emit("))") + +SetTypeDispatch = { + (Float64x4, Ptr(Float64x4)) : emit·load256, + (Ptr(Float64x4), Float64x4) : emit·store256, + (Ptr(Float64x4), Ptr(Float64x4)) : emit·copy256, +} + class Set(Assign): def emit(self): - self.lhs.emit() - emit(f" = ") - self.rhs.emit() + lhs = GetType(self.lhs) + rhs = GetType(self.rhs) + if (lhs, rhs) in SetTypeDispatch: + SetTypeDispatch[(lhs, rhs)](self.lhs, self.rhs) + else: + self.lhs.emit() + emit(f" = ") + self.rhs.emit() class AddSet(Assign): def emit(self): @@ -479,6 +541,17 @@ class For(Stmt): self.body.emit() exits_scope() + def execute(self, stmt: Stmt | Expr): + if isinstance(stmt, Expr): + stmt = StmtExpr(stmt) + + if isinstance(self.body, Empty): + self.body = stmt + return + elif not isinstance(self.body, Block): + self.body = Block([self.body]) + self.body.stmts.append(stmt) + class Return(Stmt): def __init__(self, val: Expr): self.val = val @@ -563,6 +636,9 @@ class Func(Decl): exits_scope() def declare(self, var: Var, *vars: List[Var | List[Var]]) -> Expr | List[Expr]: + if var.name in [v.name for v in self.vars]: +
return + self.vars.append(var) if numel(vars) == 0: return Ident(var) @@ -727,7 +803,163 @@ def _(x: StmtExpr, func): func(x.x) # ------------------------------------------ -# expression functions +# recreates a piece of an AST, allowing an arbitrary transformation of expression nodes + +@singledispatch +def Make(s: Stmt, func): + raise TypeError(f"{type(s)} not supported by Make operation") + +@Make.register +def _(s: Empty, func): + return Empty() + +@Make.register +def _(blk: Block, func): + return Block([func(stmt) for stmt in blk.stmts]) + +@Make.register +def _(loop: For, func): + return For(func(loop.init), func(loop.cond), func(loop.step), func(loop.body)) + +@Make.register +def _(ret: Return, func): + return Return(func(ret.val)) + +@Make.register +def _(x: StmtExpr, func): + return StmtExpr(func(x.x)) + +# ------------------------------------------ +# GetType function + +@singledispatch +def GetType(x: Expr) -> Type: + raise TypeError(f"{type(x)} not supported by GetType operation") + +@GetType.register +def _(x: Empty): + return Void + +@GetType.register +def _(x: Empty): + return Void + +@GetType.register +def _(x: S): + return Ptr(Byte) + +@GetType.register +def _(x: I): + return Int + +@GetType.register +def _(x: F): + return Float64 + +@GetType.register +def _(x: Ident): + return x.var.type + +@GetType.register +def _(x: UnaryOp): + return GetType(x.x) + +@GetType.register +def _(x: Deref): + return GetType(x.x).to + +@GetType.register +def _(x: Ref): + return Ptr(GetType(x.x)) + +@GetType.register +def _(x: Index): + base = GetType(x.x) + if isinstance(base, Ptr): + return base.to + elif isinstance(base, Array): + return base.base + else: + raise TypeError(f"attempting to index type {base}") + +# TODO: type checking for both + +@GetType.register +def _(x: BinaryOp): + lhs = GetType(x.l) + rhs = GetType(x.r) + return lhs + +@GetType.register +def _(x: Var): + return x.type + +@GetType.register +def _(x: Assign): + lhs = GetType(x.lhs) + rhs =
GetType(x.rhs) + return lhs + +@GetType.register +def _(x: Comma): + return GetType(x.expr[1]) + +@GetType.register +def _(x: Paren): + return GetType(x.x) + +@GetType.register +def _(x: Call): + return x.func.ret + +# ------------------------------------------ +# Transform function +# Recurses down AST, and modifies Expr nodes +# Returns a brand new tree structure + +@singledispatch +def Transform(x: Expr, func): + raise TypeError(f"{type(x)} not supported by Transform operation") + +@Transform.register +def _(x: Empty, func): + return func(Empty()) + +@Transform.register +def _(x: Literal, func): + return func(type(x)(x)) + +@Transform.register +def _(x: Ident, func): + return func(Ident(x.var)) + +@Transform.register +def _(x: UnaryOp, func): + return func(type(x)(Transform(x.x, func))) + +@Transform.register +def _(x: Assign, func): + return func(type(x)(Transform(x.lhs, func), Transform(x.rhs, func))) + +@Transform.register +def _(x: Deref, func): + return func(Deref(Transform(x.x, func))) + +@Transform.register +def _(x: Index, func): + return func(Index(Transform(x.x, func), Transform(x.i, func))) + +@Transform.register +def _(x: Comma, func): + return func(Comma(Transform(x.expr[0], func), Transform(x.expr[1], func))) + +@Transform.register +def _(x: BinaryOp, func): + return func(type(x)(Transform(x.l, func), Transform(x.r, func))) + +# ------------------------------------------ +# Filter function +# Recurses down AST, and stores Expr nodes that satisfy condition in results @singledispatch def Filter(x: Expr, cond, results: List[Expr]): @@ -784,13 +1016,82 @@ def _(op: BinaryOp, cond, results: List[Expr]): Filter(op.l, cond, results) Filter(op.r, cond, results) -def VarsUsed(stmt: Stmt) -> Set[Var]: +# ------------------------------------------ +# Eval function +# Recurses down AST, and evaluates Expr nodes +# Throws an error if tree can't be evaluated at compile time + +@singledispatch +def Eval(x: Expr) -> object: + raise TypeError(f"{type(x)} not
supported by Eval operation") + +@Eval.register +def _(x: Empty): + pass + +@Eval.register(Ident) +@Eval.register(Literal) +def _(x): + return x + +@Eval.register +def _(x: Inc): + return Eval(Add(Eval(x.x), I(1))) + +@Eval.register +def _(x: Dec): + return Eval(Sub(Eval(x.x), I(1))) + +# TODO: This won't work in general (if we have things like sizeof(x) + 1 - sizeof(x)) +@Eval.register +def _(op: Add): + l = Eval(op.l) + r = Eval(op.r) + return l + r + +@Eval.register +def _(op: Sub): + l = Eval(op.l) + r = Eval(op.r) + # TODO: This won't work in general (if we have things like sizeof(x) + 1 - sizeof(x)) + return l - r + +@Eval.register +def _(s: Assign): + return Eval(s.rhs) + +@Eval.register +def _(comma: Comma): + return Eval(comma.expr[1]) + +# ------------------------------------------ +# Leaf traversal + +def VarsUsed(stmt: Stmt) -> List[Var]: vars = [] - Visit(loop.body, lambda node: Filter(node, lambda x: isinstance(x, Ident), vars)) + Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), vars)) vars = set([v.var for v in vars]) return vars +def AddrAccessed(stmt: Stmt): + scalars = [] # variables that are accessed as scalars (single sites) + vectors = {} # variables that are accessed as vectors (indexed/dereferenced) + nodes = [] + Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Ident), scalars)) + Visit(stmt, lambda node: Filter(node, lambda x: isinstance(x, Index), nodes)) + + for node in nodes: + vars = [] + Filter(node.x, lambda x: isinstance(x, Ident), vars) + if numel(vars) != 1: + raise ValueError("multiple variables used in index expression not supported") + vectors[vars[0]] = node.i + if vars[0] in scalars: + scalars.remove(vars[0]) + + return frozenset(scalars), vectors + # ------------------------------------------ # Large scale functions @@ -841,13 +1142,63 @@ def Vectorize(func: Func, isa: SIMD) -> Func: print("could not vectorize loop, skipping") loop.body = body else: - instrs =
body.stmts - # TODO: Remove hardcoded 4 - if numel(instrs) % 4 != 0: + instr = body.stmts + # TODO: Remove hardcoded 4 -> should be function of types! + # As of now, we have hardcoded AVX2 <-> float64 relationship + if numel(instr) % 4 != 0: raise ValueError("loop can not be vectorized, instructions can not be globbed equally") - # TODO: Allow for non-sequential accesses? - # for i in range(numel(instrs)/4): + # TODO: Allow for non-sequential accesses? + s_symtab, v_symtab = {}, {} + for i in range(0, numel(instr), 4): + scalars = [] + vectors = [] + for j in range(4): + s, v = AddrAccessed(instr[i+j]) + scalars.append(s) + vectors.append(v) + + # Test if code is uniform to allow for vectorization + if numel(set(scalars)) != 1: + raise ValueError("non uniform scalar accesses in consecutive line. can not vectorize") + + for j in range(1, 4): + for v, idx in vectors[j].items(): + if Eval(Sub(idx, vectors[j-1][v])) != 1: + raise ValueError("non uniform vector accesses in consecutive line. can not vectorize") + + # If we made it to here, we have passed all checks. vectorize! + # Create symbol table + if i == 0: + for s in scalars[0]: + s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256")) + + for v in vectors[0].keys(): + v_symtab[v.name] = Var(Ptr(Float64x4), f"{v.name}256") + + # Necessary loads into registers + # NOTE: Generalization of a Deref + # for v in vectors[0].keys(): + # load = SIMDLoadAt(symtab[v.name], Add(v, I(i))) + # loop.execute(load) + + # IMPORTANT: We do a post-order traversal. + # The identifier will be substituted first!
+ def translate(x): + if type(x) == Ident: + if x.name in s_symtab: + return s_symtab[x.name] + elif x.name in v_symtab: + x.var = v_symtab[x.name] + if type(x) == Index: + if type(x.x) == Ident and x.x.name in v_symtab: + return Add(x.x, x.i) + + # if type(x) == Index and x.x in set(symtab.values()): + # return x.x + return x + + loop.execute(Make(instr[i], lambda node: Transform(node, translate))) vfunc.execute(loop) else: -- cgit v1.2.1