about | summary | refs | log | tree | commit | diff
diff options
context:
space:
mode:
author: Nicholas Noll <nbnoll@eml.cc> 2020-05-12 15:34:15 -0700
committer: Nicholas Noll <nbnoll@eml.cc> 2020-05-12 15:34:15 -0700
commit: 9a7c2b4f12b976ee080412c3b685fb49e3d20d34 (patch)
tree: dc65eeef3497c6badb5825c0678dbdc5c9fca2d6
parent: 4e4a3ae1b19611f0367624f541efe91aff570fab (diff)
checkpoint: plan to simplify vectorization code
-rw-r--r-- lib/c.py 139
1 files changed, 110 insertions, 29 deletions
diff --git a/lib/c.py b/lib/c.py
index d88f7ab..25c02eb 100644
--- a/lib/c.py
+++ b/lib/c.py
@@ -7,7 +7,7 @@ import sys
import sympy
from enum import Enum
-from typing import List, Tuple, Dict
+from typing import Callable, List, Tuple, Dict
from functools import singledispatch
numel = len
@@ -249,7 +249,7 @@ class Expr(Emitter):
pass
# Literals
-class Literal():
+class Literal(Expr):
def emit(self):
emit(f"{self}")
@@ -738,10 +738,18 @@ class Var(Decl):
def emitspec(self):
self.type.emitspec(self.name)
+ @classmethod
+ def copy(cls, other: Var):
+ return cls(other.type, other.name, other.storage)
+
class Param(Var):
def __init__(self, type: Type, name: str):
return super(Param, self).__init__(type, name, Mem.Auto)
+ @classmethod
+ def copy(cls, other: Param):
+ return cls(other.type, other.name)
+
def Params(*ps: List[Tuple(Type, str)]) -> List[Param]:
return [Param(p[0], p[1]) for p in ps]
@@ -897,11 +905,14 @@ def _(blk: Block, func):
@Make.register
def _(loop: For, func):
- return For(func(loop.init), func(loop.cond), func(loop.step), func(loop.body))
+ return For(func(loop.init), func(loop.cond), func(loop.step), Make(loop.body, func))
@Make.register
def _(jmp: If, func):
- return If(func(jmp.cond), Make(jmp.then, func), Make(jmp.orelse, func) if jmp.orelse is not None else None)
+ if jmp.orelse is not None:
+ return If(func(jmp.cond), Make(jmp.then, func), Make(jmp.orelse, func))
+ else:
+ return If(func(jmp.cond), Make(jmp.then, func))
@Make.register
def _(ret: Return, func):
@@ -1002,7 +1013,7 @@ def _(x: Call):
# Returns a brand new tree structure
@singledispatch
-def Transform(x: Expr, func):
+def Transform(x: Expr, func: Callable[Expr, Expr]):
raise TypeError(f"{type(x)} not supported by Transform operation")
@Transform.register
@@ -1192,29 +1203,59 @@ def Unroll(loop: For, times: int, name: str, vars: List[vars] = []) -> (For, Fun
i = loop.init.lhs.var
vars = VarsUsed(loop.body)
- params = [v for v in vars if type(v) == Param]
- stacks = [v for v in vars if type(v) == Var]
+ def asvector(v):
+ if v != i:
+ return Var(Array(v.type, times), v.name, v.storage)
+ else:
+ return Var.copy(v)
+
+ param = [Param(Ptr(i.type), f"{i.name}p")] + [Param.copy(v) for v in vars if type(v) == Param]
+ stack = {v.name: asvector(v) for v in vars if type(v) == Var}
# TODO: More sophisticated type checking
n = loop.cond.r.var
if (type(n) != Param):
raise TypeError(f"{type(n)} not implemented yet")
- params = [n] + params
- kernel = Func(f"{name}_kernel{times}", Int, params, stacks)
- i = kernel.variables("i")
+ kernel = Func(f"{name}_kernel{times}", Void, param, list(stack.values()))
body = loop.body
+ itor, itorp = kernel.variables("i", "ip")
+
+ def mkarray(x: Expr, times: int):
+ if isinstance(x, Ident):
+ if type(x.var) == Var:
+ x.var = stack[x.name]
+ return x
+
+ def step(x: Expr, times: int):
+ if x == itor:
+ return Add(x, I(times))
+ if isinstance(x, Ident):
+ if type(x.var) == Var:
+ return Index(x, I(times))
+ if not isinstance(x, Expr):
+ raise ValueError(f"panic, hit type {type(x)}")
+
+ return x
kernel.execute(
Set(n, EvenTo(n, times)),
- For(Set(i, I(0)), LT(i, n), Repeat(loop.step, times),
- body = Block(*(Make(loop.body, lambda x: Step(x, i)) for i in range(times)))
+ For(Set(itor, I(0)), LT(itor, n), Repeat(loop.step, times),
+ body = Block(*
+ (Make(
+ Make(loop.body,
+ lambda node:
+ Transform(node, lambda x: mkarray(x, times))),
+ lambda node:
+ Transform(node, lambda x: step(x, i))) for i in range(times)
+ )
+ )
),
- Return(n)
+ Set(Deref(itorp), itor)
)
loop.init = None
- return loop, kernel, Call(kernel, params)
+ return loop, kernel, Call(kernel, param)
# Replaces all vectorizable loops inside Func with vectorized variants
# Returns the new vectorized function (tagged by the SIMD chosen)
@@ -1225,7 +1266,6 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
vfunc = Func(f"{func.name}_{isa.value}", func.ret, func.params)
for stmt in func.stmts:
if IsLoop(stmt):
-
loop = For(stmt.init, stmt.cond, stmt.step)
iterator = set([stmt.init.lhs])
@@ -1240,8 +1280,12 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
if numel(instr) % 4 != 0:
raise ValueError("loop can not be vectorized, instructions can not be globbed equally")
+ class symtab(object):
+ scalar = {}
+ stack = {}
+ addrs = {}
+
# TODO: Allow for non-sequential accesses?
- s_symtab, v_symtab = {}, {}
for i in range(0, numel(instr), 4):
scalars = []
vectors = []
@@ -1266,26 +1310,31 @@ def Vectorize(func: Func, isa: SIMD) -> Func:
if i == 0:
for s in scalars[0]:
if type(s.var) == Param:
- intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128"))
- s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
+ intermediate = vfunc.declare(Var(Float64x2, f"{s.name}128"))
+ symtab.scalar[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
vfunc.execute(Set(intermediate, Ref(s)))
- vfunc.execute(Set(s_symtab[s.name], intermediate))
+ vfunc.execute(Set(symtab.scalar[s.name], intermediate))
else:
- s_symtab[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
+ symtab.scalar[s.name] = vfunc.declare(Var(Float64x4, f"{s.name}256"))
for v in vectors[0].keys():
- v_symtab[v.name] = Var(Ptr(Float64x4), f"{v.name}256")
+ if type(v.var) == Param:
+ symtab.addrs[v.name] = Var(Ptr(Float64x4), f"{v.name}256")
+ else:
+ symtab.stack[v.name] = vfunc.declare(Var(Float64x4, f"{v.name}256"))
# IMPORTANT: We do a post-order traversal.
# We transforms leaves (identifiers) first and then move back up the root
def translate(x):
if isinstance(x, Ident):
- if x.name in s_symtab:
- return s_symtab[x.name]
- elif x.name in v_symtab:
- x.var = v_symtab[x.name]
+ if x.name in symtab.scalar:
+ return symtab.scalar[x.name]
+ elif x.name in symtab.stack:
+ return symtab.stack[x.name]
+ elif x.name in symtab.addrs:
+ x.var = symtab.addrs[x.name]
if isinstance(x, Index):
- if type(x.x) == Ident and x.x.name in v_symtab:
+ if type(x.x) == Ident and (x.x.name in symtab.addrs or x.x.name in symtab.stack):
return Add(x.x, x.i)
if isinstance(x, BinaryOp):
l, r = GetType(x.l), GetType(x.r)
@@ -1312,6 +1361,37 @@ def Strided(func: Func) -> Func:
# ------------------------------------------------------------------------
# Point of testing
+def copy():
+ F = Func("blas·copy", Void,
+ Params(
+ (Int, "len"), (Ptr(Float64), "x"), (Ptr(Float64), "y")
+ ),
+ Vars(
+ (Int, "i"), (Float64, "tmp")
+ )
+ )
+ # could also declare like
+ # F.declare( ... )
+
+ len, x, y, i, tmp = F.variables("len", "x", "y", "i", "tmp")
+
+ loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
+ loop.execute(
+ Swap(Index(x, i), Index(y, i), tmp)
+ )
+
+ rem, kernel, call = Unroll(loop, 8, "copy")
+ kernel.emit()
+ emitln(2)
+
+ avx256kernel = Vectorize(kernel, SIMD.AVX2)
+ avx256kernel.emit()
+ emitln(2)
+
+ F.execute(Set(i, call), rem)
+ F.emit()
+ print(buffer)
+
def axpby():
F = Func("blas·axpby", Void,
Params(
@@ -1328,7 +1408,6 @@ def axpby():
len, x, y, i, a, b = F.variables("len", "x", "y", "i", "a", "b")
loop = For(Set(i, I(0)), LT(i, len), [Inc(i)])
- # body = Swap(Index(x, I(0)), Index(y, I(0)), tmp)
loop.execute(
Set(Index(y, i),
Mul(b,
@@ -1339,7 +1418,7 @@ def axpby():
)
)
- rem, kernel, call = Unroll(loop, 16, "axpy")
+ rem, kernel, call = Unroll(loop, 8, "axpby")
kernel.emit()
emitln(2)
@@ -1384,5 +1463,7 @@ def argmax():
if __name__ == "__main__":
emitheader()
- argmax()
+ copy()
+ # axpby()
+ # argmax()