aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-11 19:10:29 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-11 19:10:29 -0700
commitfc8fe87f6691cce902a3a0e2bf1080ae9b553db2 (patch)
tree3028cabe7758fc9140ab656f9046fe982ce49d07
parentb45f53b681d7bef4f1e96ee27f80c40dbf67573d (diff)
feat: introduced dependency on sympy to compute/simplify compile time constants
-rw-r--r--lib/c.py44
1 files changed, 25 insertions, 19 deletions
diff --git a/lib/c.py b/lib/c.py
index f85e7a7..f1aab4f 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -4,6 +4,8 @@ from __future__ import annotations
import os
import sys
+import sympy
+
from enum import Enum
from typing import List, Tuple, Dict
from functools import singledispatch
@@ -274,6 +276,9 @@ class Ident(Expr):
def emit(self):
emit(f"{self.name}")
+ def __str__(self):
+ return str(self.name)
+
# Unary operators
class UnaryOp(Expr):
def __init__(self, x: Expr):
@@ -448,14 +453,9 @@ class Set(Assign):
class AddSet(Assign):
def emit(self):
- lhs = GetType(self.lhs)
- rhs = GetType(self.rhs)
- if (lhs, rhs) in Set.method:
- Set(self.lhs, Add(self.lhs, self.rhs)).emit()
- else:
- self.lhs.emit()
- emit(f" += ")
- self.rhs.emit()
+ self.lhs.emit()
+ emit(f" += ")
+ self.rhs.emit()
class SubSet(Assign):
def emit(self):
@@ -814,7 +814,7 @@ def _(x: Deref, i: int):
@Step.register
def _(ix: Index, i: int):
- return Index(ix.x, I(ix.i + i))
+ return Index(ix.x, Add(ix.i, I(i)))
# ------------------------------------------
# bfs search on statements in ast
@@ -1077,7 +1077,12 @@ def _(x: Empty):
pass
@Eval.register(Ident)
-@Eval.register(Literal)
+@Eval.register(S)
+def _(x):
+ return sympy.symbols(f"{x}")
+
+@Eval.register(I)
+@Eval.register(F)
def _(x):
return x
@@ -1094,14 +1099,13 @@ def _(x: Dec):
def _(op: Add):
l = Eval(op.l)
r = Eval(op.r)
- return l + r
+ return sympy.simplify(l + r)
@Eval.register
def _(op: Sub):
l = Eval(op.l)
r = Eval(op.r)
- # TODO: This won't work in general (if we have things like sizeof(x) + 1 - sizeof(x))
- return l - r
+ return sympy.simplify(l - r)
@Eval.register
def _(s: Assign):
@@ -1211,7 +1215,9 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
for j in range(1, 4):
for v, idx in vectors[j].items():
- if Eval(Sub(idx, vectors[j-1][v])) != 1:
+ if (delta := (Eval(Sub(idx, vectors[j-1][v])))) != 1:
+ print(f"{delta}")
+ print(f"{sympy.simplify(delta)}")
raise ValueError("non uniform vector accesses in consecutive line. can not vectorize")
# If we made it to here, we have passed all checks. vectorize!
@@ -1290,17 +1296,17 @@ if __name__ == "__main__":
# TODO: Increase ergonomics here...
len, x, y, i, a = F.variables("len", "x", "y", "i", "a")
- loop = For(Set(i, I(0)), LT(i, len), [Inc(i), Inc(x), Inc(y)])
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
# body = Swap(Index(x, I(0)), Index(y, I(0)), tmp)
loop.execute(
- Set(Index(y, I(0)),
- Add(Index(y, I(0)),
- Mul(a, Index(x, I(0)))
+ Set(Index(y, i),
+ Add(Index(y, i),
+ Mul(a, Index(x, i))
)
)
)
- rem, kernel, call = Unroll(loop, 8, "swap")
+ rem, kernel, call = Unroll(loop, 16, "swap")
kernel.emit()
emitln(2)