From fc8fe87f6691cce902a3a0e2bf1080ae9b553db2 Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Mon, 11 May 2020 19:10:29 -0700 Subject: feat: introduced dependency on sympy to compute/simplify compile time constants --- lib/c.py | 44 +++++++++++++++++++++++++------------------- 1 file changed, 25 insertions(+), 19 deletions(-) (limited to 'lib') diff --git a/lib/c.py b/lib/c.py index f85e7a7..f1aab4f 100644 --- a/lib/c.py +++ b/lib/c.py @@ -4,6 +4,8 @@ from __future__ import annotations import os import sys +import sympy + from enum import Enum from typing import List, Tuple, Dict from functools import singledispatch @@ -274,6 +276,9 @@ class Ident(Expr): def emit(self): emit(f"{self.name}") + def __str__(self): + return str(self.name) + # Unary operators class UnaryOp(Expr): def __init__(self, x: Expr): @@ -448,14 +453,9 @@ class Set(Assign): class AddSet(Assign): def emit(self): - lhs = GetType(self.lhs) - rhs = GetType(self.rhs) - if (lhs, rhs) in Set.method: - Set(self.lhs, Add(self.lhs, self.rhs)).emit() - else: - self.lhs.emit() - emit(f" += ") - self.rhs.emit() + self.lhs.emit() + emit(f" += ") + self.rhs.emit() class SubSet(Assign): def emit(self): @@ -814,7 +814,7 @@ def _(x: Deref, i: int): @Step.register def _(ix: Index, i: int): - return Index(ix.x, I(ix.i + i)) + return Index(ix.x, Add(ix.i, I(i))) # ------------------------------------------ # bfs search on statements in ast @@ -1077,7 +1077,12 @@ def _(x: Empty): pass @Eval.register(Ident) -@Eval.register(Literal) +@Eval.register(S) +def _(x): + return sympy.symbols(f"{x}") + +@Eval.register(I) +@Eval.register(F) def _(x): return x @@ -1094,14 +1099,13 @@ def _(x: Dec): def _(op: Add): l = Eval(op.l) r = Eval(op.r) - return l + r + return sympy.simplify(l + r) @Eval.register def _(op: Sub): l = Eval(op.l) r = Eval(op.r) - # TODO: This won't work in general (if we have things like sizeof(x) + 1 - sizeof(x)) - return l - r + return sympy.simplify(l - r) @Eval.register def _(s: Assign): @@ -1211,7 +1215,9 @@ def Vectorize(func: Func, isa: SIMD) -> Func: for j in range(1, 4): for v, idx in vectors[j].items(): - if Eval(Sub(idx, vectors[j-1][v])) != 1: + if (delta := (Eval(Sub(idx, vectors[j-1][v])))) != 1: + print(f"{delta}") + print(f"{sympy.simplify(delta)}") raise ValueError("non uniform vector accesses in consecutive line. can not vectorize") # If we made it to here, we have passed all checks. vectorize! @@ -1290,17 +1296,17 @@ if __name__ == "__main__": # TODO: Increase ergonomics here... len, x, y, i, a = F.variables("len", "x", "y", "i", "a") - loop = For(Set(i, I(0)), LT(i, len), [Inc(i), Inc(x), Inc(y)]) + loop = For(Set(i, I(0)), LT(i, len), [Inc(i)]) # body = Swap(Index(x, I(0)), Index(y, I(0)), tmp) loop.execute( - Set(Index(y, I(0)), - Add(Index(y, I(0)), - Mul(a, Index(x, I(0))) + Set(Index(y, i), + Add(Index(y, i), + Mul(a, Index(x, i)) ) ) ) - rem, kernel, call = Unroll(loop, 8, "swap") + rem, kernel, call = Unroll(loop, 16, "swap") kernel.emit() emitln(2) -- cgit v1.2.1