aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-12 11:12:58 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-12 11:12:58 -0700
commit4e4a3ae1b19611f0367624f541efe91aff570fab (patch)
tree1afbc6446e9778d216d59f565469c2db541395c3
parent005feda6c41eaf28d8702aff1b6f1e79493eae78 (diff)
fix: some type errors on the AST
-rw-r--r--lib/c.py62
1 files changed, 37 insertions, 25 deletions
diff --git a/lib/c.py b/lib/c.py
index c963b5b..d88f7ab 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -538,20 +538,16 @@ class Stmt(Emitter):
emit(";")
class Empty(Stmt):
- def __init__(self):
- pass
- def emit(self):
- super(Empty, self).emit()
+ pass
class Block(Stmt):
- def __init__(self, *stmts: List[Stmt]):
- self.stmts = list(stmts)
+ def __init__(self, *stmts: List[Stmt | Expr]):
+ self.stmts = [ s if isinstance(s, Stmt) else StmtExpr(s) for s in stmts ]
def emit(self):
enter_scope()
for i, stmt in enumerate(self.stmts):
stmt.emit()
- emit(";")
if i < numel(self.stmts) - 1:
emitln()
@@ -605,7 +601,7 @@ class For(Stmt):
self.body = stmt
return
elif not isinstance(self.body, Block):
- self.body = Block([self.body])
+ self.body = Block(self.body)
self.body.stmts.append(stmt)
class Return(Stmt):
@@ -764,15 +760,21 @@ def Swap(x: Var, y: Var, tmp: Var) -> Stmt:
def IsLoop(s: Stmt) -> bool:
return type(s) == For
-def Expand(s: StmtExpr, times: int) -> Block:
- if not isinstance(s, StmtExpr):
- raise TypeError(f"{type(x)} not supported by Expand operation")
-
- return Block(*[StmtExpr(Step(s.x, i)) for i in range(times)])
-
def EvenTo(x: Var, n: int) -> Var:
return And(x, Negate(I(n-1)))
+# ------------------------------------------
+# Expand: takes statement, indexed, and expands it x times
+
+# def Expand(s: Stmt, times: int) -> Block:
+# if not isinstance(s, StmtExpr):
+# raise TypeError(f"{type(x)} not supported by Expand operation")
+
+# return Block(StmtExpr(Step(s.x, i)) for i in range(times))
+
+# ------------------------------------------
+# repeat: takes command on an array and repeats it
+
@singledispatch
def Repeat(x: Expr, times: int) -> Expr | List[Expr]:
raise TypeError(f"{type(x)} not supported by Repeat operation")
@@ -805,7 +807,9 @@ def _(x: Dec, times: int):
def _(x: Comma, times: int):
return Comma(Repeat(x.expr[0], times), Repeat(x.expr[1], times))
-# TODO: Parameterize the variables that should not be expanded
+# ------------------------------------------
+# step: indexes an expression by i
+
@singledispatch
def Step(x: Expr, i: int) -> Expr:
raise TypeError(f"{type(x)} not supported by Step operation")
@@ -852,21 +856,21 @@ def _(s: Empty, func):
@Visit.register
def _(blk: Block, func):
for stmt in blk.stmts:
- func(stmt)
+ Visit(stmt, func)
@Visit.register
def _(jmp: If, func):
func(jmp.cond)
- func(jmp.then)
+ Visit(jmp.then, func)
if jmp.orelse is not None:
- func(jmp.orelse)
+ Visit(jmp.orelse, func)
@Visit.register
def _(loop: For, func):
func(loop.init)
func(loop.cond)
func(loop.step)
- func(loop.body)
+ Visit(loop.body, func)
@Visit.register
def _(ret: Return, func):
@@ -889,7 +893,7 @@ def _(s: Empty, func):
@Make.register
def _(blk: Block, func):
- return Block(*[func(stmt) for stmt in blk.Stmts])
+ return Block(*(Make(stmt, func) for stmt in blk.stmts))
@Make.register
def _(loop: For, func):
@@ -897,7 +901,7 @@ def _(loop: For, func):
@Make.register
def _(jmp: If, func):
- return If(func(jmp.cond), func(jmp.then), func(loop.orelse) if loop.orelse is not None else None)
+ return If(func(jmp.cond), Make(jmp.then, func), Make(jmp.orelse, func) if jmp.orelse is not None else None)
@Make.register
def _(ret: Return, func):
@@ -1204,7 +1208,7 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
kernel.execute(
Set(n, EvenTo(n, times)),
For(Set(i, I(0)), LT(i, n), Repeat(loop.step, times),
- body=Expand(loop.body, times)
+ body = Block(*(Make(loop.body, lambda x: Step(x, i)) for i in range(times)))
),
Return(n)
)
@@ -1347,9 +1351,7 @@ def axpby():
F.emit()
print(buffer)
-if __name__ == "__main__":
- emitheader()
-
+def argmax():
F = Func("blas·argmax", Int,
Params(
(Int, "len"), (Ptr(Float64), "x"),
@@ -1370,7 +1372,17 @@ if __name__ == "__main__":
)
)
+ loop, kernel, call = Unroll(loop, 8, "argmax")
+ kernel.emit()
+ emitln(2)
+
F.execute(loop)
F.emit()
print(buffer)
+
+if __name__ == "__main__":
+ emitheader()
+
+ argmax()
+