aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-11 19:51:14 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-11 19:51:14 -0700
commitb4cfd0d257690aebd2df2f6ce00fb2158b3b61d5 (patch)
treedaec6d931914fa2c9bf1cdbf8029a53da8372984
parentfc8fe87f6691cce902a3a0e2bf1080ae9b553db2 (diff)
feat: added if statement
-rw-r--r--lib/c.py112
1 files changed, 78 insertions, 34 deletions
diff --git a/lib/c.py b/lib/c.py
index f1aab4f..0830dc5 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -240,8 +240,6 @@ def emit·copy256(l, r):
r.emit()
emit("))")
-
-
# TODO: Typedefs...
# ------------------------------------------
@@ -273,11 +271,18 @@ class Ident(Expr):
self.name = var.name
self.var = var
+ def __str__(self):
+ return str(self.name)
+
+ def __hash__(self):
+ return hash(self.name)
+
+ def __eq__(self, other):
+ return type(self) == type(other) and self.name == other.name
+
def emit(self):
emit(f"{self.name}")
- def __str__(self):
- return str(self.name)
# Unary operators
class UnaryOp(Expr):
@@ -295,7 +300,6 @@ class Deref(UnaryOp):
def emit(self):
kind = GetType(self.x)
if kind in self.method:
- print("emitting")
self.method[kind](self.x)
else:
emit("*")
@@ -540,18 +544,33 @@ class Empty(Stmt):
super(Empty, self).emit()
class Block(Stmt):
- def __init__(self, stmts: List[Stmt]):
- self.stmts = stmts
+ def __init__(self, *stmts: List[Stmt]):
+ self.stmts = list(stmts)
def emit(self):
enter_scope()
for i, stmt in enumerate(self.stmts):
stmt.emit()
+ emit(";")
if i < numel(self.stmts) - 1:
emitln()
exits_scope()
+class If(Stmt):
+ def __init__(self, cond: Expr | List[Expr], then: Stmt | List[Stmt], orelse: Stmt | List[Stmt] | None = None):
+ self.cond = cond if isinstance(cond, Expr) else ExprList(*cond)
+ self.then = Block(then) if isinstance(then, list) else then
+ self.orelse = Block(orelse) if isinstance(orelse, list) else orelse
+
+ def emit(self):
+ emit("if (")
+ self.cond.emit()
+ emit(") ")
+ self.then.emit()
+ if self.orelse is not None:
+ self.orelse.emit()
+
class For(Stmt):
def __init__(self, init: Expr | List[Expr], cond: Expr | List[Expr], step: Expr | List[Expr], body: Stmt| None = None):
self.init = init if isinstance(init, Expr) else ExprList(*init)
@@ -700,6 +719,9 @@ class Func(Decl):
def variables(self, *idents: List[str]) -> List[Expr]:
vars = {v.name : v for v in self.vars + self.params}
+
+ if numel(idents) == 1:
+ return Ident(vars[idents[0]])
return [Ident(vars[ident]) for ident in idents]
class Var(Decl):
@@ -746,7 +768,7 @@ def Expand(s: StmtExpr, times: int) -> Block:
if not isinstance(s, StmtExpr):
raise TypeError(f"{type(x)} not supported by Expand operation")
- return Block([StmtExpr(Step(s.x, i)) for i in range(times)])
+ return Block(*[StmtExpr(Step(s.x, i)) for i in range(times)])
def EvenTo(x: Var, n: int) -> Var:
return And(x, Negate(I(n-1)))
@@ -860,7 +882,7 @@ def _(s: Empty, func):
@Make.register
def _(blk: Block, func):
- return Block([func(stmt) for stmt in blk.Stmts])
+ return Block(*[func(stmt) for stmt in blk.Stmts])
@Make.register
def _(loop: For, func):
@@ -1010,7 +1032,7 @@ def _(x: BinaryOp, func):
@singledispatch
def Filter(x: Expr, cond, results: List[Expr]):
- raise TypeError(f"{type(s)} not supported by Filter operation")
+ raise TypeError(f"{type(x)} not supported by Filter operation")
@Filter.register
def _(x: Empty, cond, results: List[Expr]):
@@ -1141,7 +1163,7 @@ def AddrAccessed(stmt: Stmt):
if vars[0] in scalars:
scalars.remove(vars[0])
- return frozenset(scalars), vectors
+ return set(scalars), vectors
# ------------------------------------------
# Large scale functions
@@ -1152,7 +1174,7 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
raise TypeError(f"{type(loop.cond)} not supported in loop unrolling")
# pull off needed features of the loop
- it = loop.init.lhs.var
+ i = loop.init.lhs.var
vars = VarsUsed(loop.body)
params = [v for v in vars if type(v) == Param]
@@ -1165,11 +1187,12 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
params = [n] + params
kernel = Func(f"{name}_kernel{times}", Int, params, stacks)
+ i = kernel.variables("i")
body = loop.body
kernel.execute(
Set(n, EvenTo(n, times)),
- For(Set(it, I(0)), LT(it, n), Repeat(loop.step, times),
+ For(Set(i, I(0)), LT(i, n), Repeat(loop.step, times),
body=Expand(loop.body, times)
),
Return(n)
@@ -1187,7 +1210,10 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params)
for stmt in func.stmts:
if IsLoop(stmt):
- loop = For(stmt.init, stmt.cond, stmt.step)
+
+ loop = For(stmt.init, stmt.cond, stmt.step)
+ iterator = set([stmt.init.lhs])
+
body = stmt.body
if type(body) != Block:
print("could not vectorize loop, skipping")
@@ -1206,7 +1232,8 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
vectors = []
for j in range(4):
s, v = AddrAccessed(instr[i+j])
- scalars.append(s)
+ s -= iterator # TODO: This is hacky
+ scalars.append(frozenset(s))
vectors.append(v)
# Test if code in uniform to allow for vectorization
@@ -1217,7 +1244,6 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
for v, idx in vectors[j].items():
if (delta := (Eval(Sub(idx, vectors[j-1][v])))) != 1:
print(f"{delta}")
- print(f"{sympy.simplify(delta)}")
raise ValueError("non uniform vector accesses in consecutive line. can not vectorize")
# If we made it to here, we have passed all checks. vectorize!
@@ -1235,12 +1261,6 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
for v in vectors[0].keys():
v_symtab[v.name] = Var(Ptr(Float64x4), f"{v.name}256")
- # Necessary loads into registers
- # NOTE: Generalization of a Deref
- # for v in vectors[0].keys():
- # load = SIMDLoadAt(symtab[v.name], Add(v, I(i)))
- # loop.execute(load)
-
# IMPORTANT: We do a post-order traversal.
# We transforms leaves (identifiers) first and then move back up the root
def translate(x):
@@ -1260,8 +1280,6 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
if type(r) == Ptr:
x.r = Deref(x.r)
- # if type(x) == Index and x.x in set(symtab.values()):
- # return x.x
return x
loop.execute(Make(instr[i], lambda node: Transform(node, translate)))
@@ -1279,12 +1297,10 @@ def Strided(func: Func) -> Func:
# ------------------------------------------------------------------------
# Point of testing
-if __name__ == "__main__":
- emitheader()
-
- F = Func("blas·axpy", Void,
+def axpby():
+ F = Func("blas·axpby", Void,
Params(
- (Int, "len"), (Float64, "a"), (Ptr(Float64), "x"), (Ptr(Float64), "y")
+ (Int, "len"), (Float64, "a"), (Ptr(Float64), "x"), (Float64, "b"), (Ptr(Float64), "y")
),
Vars(
(Int, "i"), #(Float64, "tmp")
@@ -1294,19 +1310,21 @@ if __name__ == "__main__":
# F.declare( ... )
# TODO: Increase ergonomics here...
- len, x, y, i, a = F.variables("len", "x", "y", "i", "a")
+ len, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b")
loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
# body = Swap(Index(x, I(0)), Index(y, I(0)), tmp)
loop.execute(
Set(Index(y, i),
- Add(Index(y, i),
- Mul(a, Index(x, i))
+ Mul(b,
+ Add(Index(y, i),
+ Mul(a, Index(x, i)),
+ )
)
)
)
- rem, kernel, call = Unroll(loop, 16, "swap")
+ rem, kernel, call = Unroll(loop, 16, "axpy")
kernel.emit()
emitln(2)
@@ -1315,7 +1333,33 @@ if __name__ == "__main__":
emitln(2)
F.execute(Set(i, call), rem)
-
F.emit()
+ print(buffer)
+
+if __name__ == "__main__":
+ emitheader()
+
+ F = Func("blas·argmax", Int,
+ Params(
+ (Int, "len"), (Ptr(Float64), "x"),
+ ),
+ Vars(
+ (Int, "i"), (Int, "ix"), (Float64, "max")
+ )
+ )
+ len, x, i, ix, max = F.variables("len", "x", "i", "ix", "max")
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ If(GT(Index(x, i), max),
+ Block(
+ Set(ix, i),
+ Set(max, Index(x, i)),
+ )
+ )
+ )
+
+ F.execute(loop)
+
+ F.emit()
print(buffer)