aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-10 21:06:45 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-10 21:06:45 -0700
commit01b68aff4853c3a4b4349675f2f40575a4538fff (patch)
treeb746942a6166ac29aaf5c90325801282a82296b0
parentacf54a4b990634e23a88b04363923c4b108a0304 (diff)
fix: ergonomics
-rw-r--r--lib/c.py109
1 files changed, 97 insertions, 12 deletions
diff --git a/lib/c.py b/lib/c.py
index 54b8e14..1554cc4 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -35,6 +35,12 @@ def exits_scope():
emitln()
emit("}")
+def emitheader():
+ emit("#include <u.h>\n")
+ emit("#include <libn.h>\n")
+ emit("#include <libmath.h>\n")
+ emitln()
+
# ------------------------------------------------------------------------
# Simple C AST
# TODO: Type checking
@@ -97,6 +103,8 @@ Void = Base("void")
Error = Base("error")
Int = Base("int")
+Int8 = Base("int8")
+Int16 = Base("int16")
Int32 = Base("int32")
Int64 = Base("int64")
Int32x4 = Base("__mm128i")
@@ -111,6 +119,44 @@ Float32x8 = Base("__mm256")
Float64x2 = Base("__m128d")
Float64x4 = Base("__mm256d")
+# TODO: Make this ARCH dependent
+BitDepth= {
+ Void: 8,
+ Int: 32,
+ Int32: 32,
+ Int64: 64,
+ Float32: 64,
+ Float64: 64,
+}
+
+class SIMD(Enum):  # instruction-set tags; keys of RegisterSize and SIMDSupport
+ SSE = "sse"  # each value string is appended to function names by Vectorize
+ SSE2 = "sse2"
+ AVX = "avx"
+ AVX2 = "avx2"  # the only member Vectorize currently accepts
+ FMA3 = "fma3"
+ AVX5 = "avx512"  # member is named AVX5 but its value is "avx512"
+
+RegisterSize = {
+ SIMD.SSE: 128,
+ SIMD.SSE2: 128,
+ SIMD.AVX: 256,
+ SIMD.AVX2: 256,
+ SIMD.FMA3: 256,
+ SIMD.AVX5: 512,
+}
+
+SIMDSupport = {
+ SIMD.SSE: set([Float32]),
+ SIMD.SSE2: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
+ SIMD.AVX: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
+ SIMD.AVX2: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
+ SIMD.FMA3: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
+ SIMD.AVX5: set([Float32, Float64, Int, Int8, Int16, Int32, Int64]),
+}
+
+# TODO: Operation lookup tables...
+
class Ptr(Type):
def __init__(self, to: Type):
self.to = to
@@ -365,7 +411,7 @@ class Call(Expr):
self.args = args
def emit(self):
- emit(self.func.ident)
+ emit(self.func.name)
emit("(")
if numel(self.args) > 0:
self.args[0].emit()
@@ -408,11 +454,11 @@ class Block(Stmt):
exits_scope()
class For(Stmt):
- def __init__(self, init: Expr | List[Expr], cond: Expr | List[Expr], step: Expr | List[Expr], body: Stmt):
+ def __init__(self, init: Expr | List[Expr], cond: Expr | List[Expr], step: Expr | List[Expr], body: Stmt| None = None):
self.init = init if isinstance(init, Expr) else ExprList(*init)
self.cond = cond if isinstance(cond, Expr) else ExprList(*cond)
self.step = step if isinstance(step, Expr) else ExprList(*step)
- self.body = body
+ self.body = body if body is not None else Empty()
def emit(self):
emit("for (")
@@ -466,8 +512,8 @@ class Decl(Emitter):
pass
class Func(Decl):
- def __init__(self, ident: str, ret: Expr = Void, params: List[Param] = None, vars: List[Var | List[Var]] = None, body: List[Stmt] = None):
- self.ident = ident
+ def __init__(self, name: str, ret: Expr = Void, params: List[Param] = None, vars: List[Var | List[Var]] = None, body: List[Stmt] = None):
+ self.name = name
self.ret = ret
self.params = params if params is not None else []
self.vars = vars if vars is not None else []
@@ -476,7 +522,7 @@ class Func(Decl):
def emit(self):
self.ret.emit()
emitln()
- emit(self.ident)
+ emit(self.name)
emit("(")
for i, p in enumerate(self.params):
p.emittype()
@@ -577,6 +623,11 @@ def Vars(*ps: List[Tuple(Type, str)]) -> List[Var]:
def Swap(x: Var, y: Var, tmp: Var) -> Stmt:
return StmtExpr(ExprList(Set(tmp, x), Set(x, y), Set(y, tmp)))
+def IsLoop(s: Stmt) -> bool:
+ return type(s) == For
+
+# TODO: Generalize to memory addresses!
+# This will allow better vectorization
@singledispatch
def VarsUsed(s: object) -> List[Vars]:
raise TypeError(f"{type(s)} not supported by VarsUsed operation")
@@ -601,6 +652,13 @@ def _(lit: Literal):
return []
@VarsUsed.register
+def _(i: Index):
+ vars = []
+ vars.extend(VarsUsed(i.x))
+ vars.extend(VarsUsed(i.i))
+ return vars
+
+@VarsUsed.register
def _(comma: Comma):
vars = []
vars.extend(VarsUsed(comma.expr[0]))
@@ -679,6 +737,10 @@ def _(x: Ident, i: int):
def _(x: Deref, i: int):
return Index(x.x, I(i))
+@Step.register
+def _(ix: Index, i: int):
+ return Index(ix.x, I(ix.i + i))
+
def Expand(s: StmtExpr, times: int) -> Block:
if not isinstance(s, StmtExpr):
raise TypeError(f"{type(x)} not supported by Expand operation")
@@ -706,7 +768,7 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
raise TypeError(f"{type(n)} not implemented yet")
params = [n] + params
- kernel = Func(f"{name}{times}", Int, params, stacks)
+ kernel = Func(f"{name}_kernel{times}", Int, params, stacks)
body = loop.body
kernel.execute(
@@ -720,32 +782,55 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
loop.init = None
return loop, kernel, Call(kernel, params)
+# Replaces all vectorizable loops inside Func with vectorized variants
+# Returns the new vectorized function (tagged by the SIMD chosen)
+def Vectorize(func: Func, arrays: List[Var], isa: SIMD) -> Func:
+ if isa != SIMD.AVX2:
+ raise ValueError(f"ISA '{isa}' not currently implemented")
+
+ vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params, func.vars)
+ for stmt in func.stmts:
+ if IsLoop(stmt):
+ loop = For(stmt.init, stmt.cond, stmt.step)
+ vfunc.execute(loop)
+ else:
+ vfunc.execute(stmt)
+
+ return vfunc
+
# ------------------------------------------------------------------------
# Point of testing
if __name__ == "__main__":
+ emitheader()
+
Rot = Func("blas·swap", Void,
- params = Params(
+ Params(
(Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y")
),
- vars = Vars(
+ Vars(
(Int, "i"), (Float64, "tmp")
)
)
+ # could also declare like
# Rot.declare( ... )
# TODO: Increase ergonomics here...
len, x, y, i, tmp = Rot.variables("len", "x", "y", "i", "tmp")
loop = For(Set(i, I(0)), LT(i, len), [Inc(i), Inc(x), Inc(y)],
- body = Swap(Deref(x), Deref(y), tmp)
+ body = Swap(Index(x, I(0)), Index(y, I(0)), tmp)
)
- rem, kernel, it = Unroll(loop, 8, "swap_kernel")
+ rem, kernel, call = Unroll(loop, 8, "swap")
kernel.emit()
emitln(2)
- Rot.execute(Set(i, it), rem)
+ avx256kernel = Vectorize(kernel, [], SIMD.AVX2)
+ avx256kernel.emit()
+ emitln(2)
+
+ Rot.execute(Set(i, call), rem)
Rot.emit()