From 01b68aff4853c3a4b4349675f2f40575a4538fff Mon Sep 17 00:00:00 2001
From: Nicholas Noll
Date: Sun, 10 May 2020 21:06:45 -0700
Subject: fix: ergonomics

---
 lib/c.py | 109 ++++++++++++++++++++++++++++++++++++++++++++++++++++++++-------
 1 file changed, 97 insertions(+), 12 deletions(-)

(limited to 'lib')

diff --git a/lib/c.py b/lib/c.py
index 54b8e14..1554cc4 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -35,6 +35,12 @@ def exits_scope():
     emitln()
     emit("}")
 
+def emitheader():  # emits the #include preamble of the generated C source
+    emit("#include \n")  # NOTE(review): header name after #include is missing in this copy - restore from upstream
+    emit("#include \n")
+    emit("#include \n")
+    emitln()
+
 # ------------------------------------------------------------------------
 # Simple C AST
 # TODO: Type checking
@@ -97,6 +103,8 @@ Void = Base("void")
 
 Error = Base("error")
 Int = Base("int")
+Int8 = Base("int8")
+Int16 = Base("int16")
 Int32 = Base("int32")
 Int64 = Base("int64")
 Int32x4 = Base("__mm128i")
@@ -111,6 +119,44 @@ Float32x8 = Base("__mm256")
 Float64x2 = Base("__m128d")
 Float64x4 = Base("__mm256d")
 
+# TODO: Make this ARCH dependent
+BitDepth= {
+    Void: 8,
+    Int: 32,
+    Int32: 32,
+    Int64: 64,
+    Float32: 32,  # single precision: 8 lanes fill the 256-bit Float32x8 register
+    Float64: 64,
+}
+
+class SIMD(Enum):
+    SSE = "sse"
+    SSE2 = "sse2"
+    AVX = "avx"
+    AVX2 = "avx2"
+    FMA3 = "fma3"
+    AVX5 = "avx512"
+
+RegisterSize = {  # vector register width in bits for each instruction set
+    SIMD.SSE: 128,
+    SIMD.SSE2: 128,
+    SIMD.AVX: 256,
+    SIMD.AVX2: 256,
+    SIMD.FMA3: 256,
+    SIMD.AVX5: 512,
+}
+
+SIMDSupport = {
+    SIMD.SSE: set([Float32]),
+    SIMD.SSE2: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
+    SIMD.AVX: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
+    SIMD.AVX2: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
+    SIMD.FMA3: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
+    SIMD.AVX5: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
+}
+
+# TODO: Operation lookup tables...
+
 class Ptr(Type):
     def __init__(self, to: Type):
         self.to = to
@@ -365,7 +411,7 @@ class Call(Expr):
         self.args = args
 
     def emit(self):
-        emit(self.func.ident)
+        emit(self.func.name)
         emit("(")
         if numel(self.args) > 0:
             self.args[0].emit()
@@ -408,11 +454,11 @@ class Block(Stmt):
         exits_scope()
 
 class For(Stmt):
-    def __init__(self, init: Expr | List[Expr], cond: Expr | List[Expr], step: Expr | List[Expr], body: Stmt):
+    def __init__(self, init: Expr | List[Expr], cond: Expr | List[Expr], step: Expr | List[Expr], body: Stmt| None = None):
         self.init = init if isinstance(init, Expr) else ExprList(*init)
         self.cond = cond if isinstance(cond, Expr) else ExprList(*cond)
         self.step = step if isinstance(step, Expr) else ExprList(*step)
-        self.body = body
+        self.body = body if body is not None else Empty()
 
     def emit(self):
         emit("for (")
@@ -466,8 +512,8 @@ class Decl(Emitter):
     pass
 
 class Func(Decl):
-    def __init__(self, ident: str, ret: Expr = Void, params: List[Param] = None, vars: List[Var | List[Var]] = None, body: List[Stmt] = None):
-        self.ident = ident
+    def __init__(self, name: str, ret: Expr = Void, params: List[Param] = None, vars: List[Var | List[Var]] = None, body: List[Stmt] = None):
+        self.name = name
         self.ret = ret
         self.params = params if params is not None else []
         self.vars = vars if vars is not None else []
@@ -476,7 +522,7 @@ class Func(Decl):
     def emit(self):
         self.ret.emit()
         emitln()
-        emit(self.ident)
+        emit(self.name)
         emit("(")
         for i, p in enumerate(self.params):
             p.emittype()
@@ -577,6 +623,11 @@ def Vars(*ps: List[Tuple(Type, str)]) -> List[Var]:
 def Swap(x: Var, y: Var, tmp: Var) -> Stmt:
     return StmtExpr(ExprList(Set(tmp, x), Set(x, y), Set(y, tmp)))
 
+def IsLoop(s: Stmt) -> bool:  # True when s is a For statement (subclasses of For included)
+    return isinstance(s, For)
+
+# TODO: Generalize to memory addresses!
+# This will allow better vectorization
 @singledispatch
 def VarsUsed(s: object) -> List[Vars]:
     raise TypeError(f"{type(s)} not supported by VarsUsed operation")
@@ -600,6 +651,13 @@ def _(s: Assign):
 def _(lit: Literal):
     return []
 
+@VarsUsed.register
+def _(i: Index):
+    vars = []
+    vars.extend(VarsUsed(i.x))
+    vars.extend(VarsUsed(i.i))
+    return vars
+
 @VarsUsed.register
 def _(comma: Comma):
     vars = []
@@ -679,6 +737,10 @@ def _(x: Ident, i: int):
 def _(x: Deref, i: int):
     return Index(x.x, I(i))
 
+@Step.register
+def _(ix: Index, i: int):
+    return Index(ix.x, I(ix.i + i))
+
 def Expand(s: StmtExpr, times: int) -> Block:
     if not isinstance(s, StmtExpr):
         raise TypeError(f"{type(x)} not supported by Expand operation")
@@ -706,7 +768,7 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
         raise TypeError(f"{type(n)} not implemented yet")
 
     params = [n] + params
-    kernel = Func(f"{name}{times}", Int, params, stacks)
+    kernel = Func(f"{name}_kernel{times}", Int, params, stacks)
 
     body = loop.body
     kernel.execute(
@@ -720,32 +782,55 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
     loop.init = None
     return loop, kernel, Call(kernel, params)
 
+# Replaces all vectorizable loops inside Func with vectorized variants
+# Returns the new vectorized function (tagged by the SIMD chosen)
+def Vectorize(func: Func, arrays: List[Var], isa: SIMD) -> Func:
+    if isa != SIMD.AVX2:
+        raise ValueError(f"ISA '{isa}' not currently implemented")
+
+    vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params, func.vars)
+    for stmt in func.stmts:
+        if IsLoop(stmt):
+            loop = For(stmt.init, stmt.cond, stmt.step)
+            vfunc.execute(loop)
+        else:
+            vfunc.execute(stmt)
+
+    return vfunc
+
 # ------------------------------------------------------------------------
 # Point of testing
 
 if __name__ == "__main__":
+    emitheader()
+
     Rot = Func("blas·swap", Void,
-        params = Params(
+        Params(
             (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y")
         ),
-        vars = Vars(
+        Vars(
             (Int, "i"), (Float64, "tmp")
         )
     )
+
     # could also declare like
     # Rot.declare( ... )
 
     # TODO: Increase ergonomics here...
     len, x, y, i, tmp = Rot.variables("len", "x", "y", "i", "tmp")
 
     loop = For(Set(i, I(0)), LT(i, len), [Inc(i), Inc(x), Inc(y)],
-        body = Swap(Deref(x), Deref(y), tmp)
+        body = Swap(Index(x, I(0)), Index(y, I(0)), tmp)
     )
 
-    rem, kernel, it = Unroll(loop, 8, "swap_kernel")
+    rem, kernel, call = Unroll(loop, 8, "swap")
 
     kernel.emit()
     emitln(2)
 
-    Rot.execute(Set(i, it), rem)
+    avx256kernel = Vectorize(kernel, [], SIMD.AVX2)
+    avx256kernel.emit()
+    emitln(2)
+
+    Rot.execute(Set(i, call), rem)
     Rot.emit()
-- 
cgit v1.2.1