aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-11 14:17:11 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-11 14:17:11 -0700
commitfd50eff732d45ba2c4400fabf301c65c776f2338 (patch)
treef588effc028d57636a2e576ef0e057a44b31a1a5
parent01b68aff4853c3a4b4349675f2f40575a4538fff (diff)
refactor: pulled out statement BFS code from expr eval code
-rw-r--r--lib/c.py194
1 files changed, 118 insertions, 76 deletions
diff --git a/lib/c.py b/lib/c.py
index 1554cc4..518952c 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -620,85 +620,39 @@ def Vars(*ps: List[Tuple(Type, str)]) -> List[Var]:
# ------------------------------------------------------------------------
# AST modification/production functions
+# ------------------------------------------
+# basic (non-recursive) commands
+
def Swap(x: Var, y: Var, tmp: Var) -> Stmt:
return StmtExpr(ExprList(Set(tmp, x), Set(x, y), Set(y, tmp)))
def IsLoop(s: Stmt) -> bool:
return type(s) == For
-# TODO: Generalize to memory addresses!
-# This will allow better vectorization
-@singledispatch
-def VarsUsed(s: object) -> List[Vars]:
- raise TypeError(f"{type(s)} not supported by VarsUsed operation")
-
-@VarsUsed.register
-def _(op: UnaryOp):
- return VarsUsed(op.x)
-
-@VarsUsed.register
-def _(sym: Ident):
- return [sym.var]
-
-@VarsUsed.register
-def _(s: Assign):
- vars = []
- vars.extend(VarsUsed(s.lhs))
- vars.extend(VarsUsed(s.rhs))
- return vars
-
-@VarsUsed.register
-def _(lit: Literal):
- return []
-
-@VarsUsed.register
-def _(i: Index):
- vars = []
- vars.extend(VarsUsed(i.x))
- vars.extend(VarsUsed(i.i))
- return vars
-
-@VarsUsed.register
-def _(comma: Comma):
- vars = []
- vars.extend(VarsUsed(comma.expr[0]))
- vars.extend(VarsUsed(comma.expr[1]))
- return vars
+def Expand(s: StmtExpr, times: int) -> Block:
+ if not isinstance(s, StmtExpr):
+ raise TypeError(f"{type(x)} not supported by Expand operation")
-@VarsUsed.register
-def _(op: BinaryOp):
- vars = []
- vars.extend(VarsUsed(op.l))
- vars.extend(VarsUsed(op.r))
- return vars
+ return Block([StmtExpr(Step(s.x, i)) for i in range(times)])
-@VarsUsed.register
-def _(s: Empty):
- return []
+def EvenTo(x: Var, n: int) -> Var:
+ return And(x, Negate(I(n-1)))
-@VarsUsed.register
-def _(blk: Block):
- vars = []
- for stmt in blk.stmts:
- vars.extend(VarsUsed(stmt))
- return vars
+@singledispatch
+def Repeat(x: Expr, times: int) -> Expr | List[Expr]:
+ raise TypeError(f"{type(x)} not supported by Repeat operation")
-@VarsUsed.register
-def _(loop: For):
- vars = []
- vars.extend(VarsUsed(loop.init))
- vars.extend(VarsUsed(loop.cond))
- vars.extend(VarsUsed(loop.step))
- vars.extend(VarsUsed(loop.body))
- return vars
+@Repeat.register
+def _(x: Inc, times: int):
+ return AddSet(x.x, I(times))
-@VarsUsed.register
-def _(ret: Return):
- return VarsUsed(ret.val)
+@Repeat.register
+def _(x: Dec, times: int):
+ return DecSet(x.x, I(times))
-@VarsUsed.register
-def _(x: StmtExpr):
- return VarsUsed(x.x)
+@Repeat.register
+def _(x: Comma, times: int):
+ return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times))
@singledispatch
def Repeat(x: Expr, times: int) -> Expr | List[Expr]:
@@ -741,14 +695,83 @@ def _(x: Deref, i: int):
def _(ix: Index, i: int):
return Index(ix.x, I(ix.i + i))
-def Expand(s: StmtExpr, times: int) -> Block:
- if not isinstance(s, StmtExpr):
- raise TypeError(f"{type(x)} not supported by Expand operation")
+# ------------------------------------------
+# bfs search on statements in ast
- return Block([StmtExpr(Step(s.x, i)) for i in range(times)])
+@singledispatch
+def Visit(s: Stmt, func):
+ raise TypeError(f"{type(s)} not supported by VarsAccessed operation")
-def EvenTo(x: Var, n: int) -> Var:
- return And(x, Negate(I(n-1)))
+@Visit.register
+def _(s: Empty, func):
+ return
+
+@Visit.register
+def _(blk: Block, func):
+ for stmt in blk.stmts:
+ func(stmt)
+
+@Visit.register
+def _(loop: For, func):
+ func(loop.init)
+ func(loop.cond)
+ func(loop.step)
+ func(loop.body)
+
+@Visit.register
+def _(ret: Return, func):
+ func(ret.val)
+
+@Visit.register
+def _(x: StmtExpr, func):
+ func(x.x)
+
+# ------------------------------------------
+# expression functions
+
+@singledispatch
+def VarsAccessed(x: Expr, vars: List[Vars]):
+ raise TypeError(f"{type(s)} not supported by VarsAccessed operation")
+
+@VarsAccessed.register(Empty)
+@VarsAccessed.register(Literal)
+def _(x, vars):
+ return
+
+@VarsAccessed.register
+def _(op: UnaryOp, vars):
+ return VarsAccessed(op.x, vars)
+
+@VarsAccessed.register
+def _(sym: Ident, vars):
+ vars.append(sym.var)
+
+@VarsAccessed.register
+def _(s: Assign, vars):
+ VarsAccessed(s.lhs, vars)
+ VarsAccessed(s.rhs, vars)
+
+@VarsAccessed.register
+def _(v: Deref, vars):
+ VarsAccessed(v.x, vars)
+
+@VarsAccessed.register
+def _(i: Index, vars):
+ VarsAccessed(i.x, vars)
+ VarsAccessed(i.i, vars)
+
+@VarsAccessed.register
+def _(comma: Comma, vars):
+ VarsAccessed(comma.expr[0], vars)
+ VarsAccessed(comma.expr[1], vars)
+
+@VarsAccessed.register
+def _(op: BinaryOp, vars):
+ VarsAccessed(op.l, vars)
+ VarsAccessed(op.r, vars)
+
+# ------------------------------------------
+# Large scale functions
def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Func, Call):
# TODO: More sophisticated computation for length of loop
@@ -757,7 +780,9 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
# pull off needed features of the loop
it = loop.init.lhs.var
- vars = set(VarsUsed(loop.body))
+ vars = []
+ Visit(loop.body, lambda node: VarsAccessed(node, vars))
+ print(vars)
params = [v for v in vars if type(v) == Param]
stacks = [v for v in vars if type(v) == Var]
@@ -784,7 +809,7 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
# Replaces all vectorizable loops inside Func with vectorized variants
# Returns the new vectorized function (tagged by the SIMD chosen)
-def Vectorize(func: Func, arrays: List[Var], isa: SIMD) -> Func:
+def Vectorize(func: Func, isa: SIMD) -> Func:
if isa != SIMD.AVX2:
raise ValueError(f"ISA '{isa}' not currently implemented")
@@ -792,12 +817,29 @@ def Vectorize(func: Func, arrays: List[Var], isa: SIMD) -> Func:
for stmt in func.stmts:
if IsLoop(stmt):
loop = For(stmt.init, stmt.cond, stmt.step)
+ body = stmt.body
+ if type(body) != Block:
+ print("could not vectorize loop, skipping")
+ loop.body = body
+ else:
+ instrs = body.stmts
+ # TODO: Remove hardcoded 4
+ if numel(instrs) % 4 != 0:
+ raise ValueError("loop can not be vectorized, instructions can not be globbed equally")
+ # TODO: Allow for non-sequential accesses?
+ # for i in range(numel(instrs)/4):
+
+
vfunc.execute(loop)
else:
vfunc.execute(stmt)
return vfunc
+
+def Strided(func: Func) -> Func:
+ pass
+
# ------------------------------------------------------------------------
# Point of testing
@@ -826,7 +868,7 @@ if __name__ == "__main__":
kernel.emit()
emitln(2)
- avx256kernel = Vectorize(kernel, [], SIMD.AVX2)
+ avx256kernel = Vectorize(kernel, SIMD.AVX2)
avx256kernel.emit()
emitln(2)