aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-10 14:23:32 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-10 14:23:32 -0700
commit9828864af9fb9d39d36e4f3811ce8c3283e8433d (patch)
treea0dde6aee34b50d82d89ecb12c7c4b1b6d5949cc
parent65fe4a1ddd852c9c702ae008c3b880a20b84d8e9 (diff)
simple python library to directly write a C AST
-rw-r--r--lib/c.py415
1 files changed, 415 insertions, 0 deletions
diff --git a/lib/c.py b/lib/c.py
new file mode 100644
index 0000000..4f0ae29
--- /dev/null
+++ b/lib/c.py
@@ -0,0 +1,415 @@
+#! /bin/python3
+from __future__ import annotations
+
+import os
+import sys
+
+from abc import ABC, abstractmethod
+from enum import Enum
+from typing import List
+
+# ------------------------------------------------------------------------
+# String buffer
+
+level = 0
+buffer = ""
+
+def emit(s: str):
+ global buffer
+ buffer += s
+
+def emitln(n=1):
+ global buffer
+ buffer += "\n"*n + (" " * level)
+
+def enter_scope():
+ global level
+ emit("{")
+ level += 1
+ emitln()
+
+def exits_scope():
+ global level
+ level -= 1
+ emitln()
+ emit("}")
+
+# ------------------------------------------------------------------------
+# Simple C AST
+# TODO: Type checking
+
+# Abstract class everything will derive from
+# All AST nodes will have an "emit" function that outputs formatted C code
+class Emitter(ABC):
+ @abstractmethod
+ def emit(self):
+ pass
+
+# ------------------------------------------
+# Representation of a C type
+class TypeKind(Enum):
+ Void = "void"
+ Error = "error"
+
+ Int = "int"
+ Int32 = "int32"
+ Int64 = "int64"
+
+ # vectorized variants
+ Int32x4 = "__mm128i"
+ Int32x8 = "__mm256i"
+ Int64x2 = "__mm128i"
+ Int64x4 = "__mm256i"
+
+ Float32 = "float"
+ Float64 = "double"
+
+ # vectorized variants
+ Float32x4 = "__m128"
+ Float32x8 = "__mm256"
+ Float64x2 = "__mm128d"
+ Float64x4 = "__mm256d"
+
+ Pointer = "pointer"
+ Struct = "struct"
+ Enum = "enum"
+ Union = "union"
+
+class Type(Emitter):
+ def emit(self):
+ pass
+
+ def emitspec(self, var):
+ pass
+
+class Base(Type):
+ def __init__(self, name: str):
+ self.name = name
+
+ def emit(self):
+ emit(self.name)
+
+ def emitspec(self, ident):
+ emit(f"{ident}")
+
+# Machine primitive types
+Void = Base("void")
+Error = Base("error")
+
+Int = Base("int")
+Int32 = Base("int32")
+Int64 = Base("int64")
+Int32x4 = Base("__mm128i")
+Int32x8 = Base("__mm256i")
+Int64x2 = Base("__mm128i")
+Int64x4 = Base("__mm256i")
+
+Float32 = Base("float")
+Float64 = Base("double")
+Float32x4 = Base("__m128")
+Float32x8 = Base("__mm256")
+Float64x2 = Base("__m128d")
+Float64x4 = Base("__mm256d")
+
+class Ptr(Type):
+ def __init__(self, to: Type):
+ self.to = to
+
+ def emit(self):
+ self.to.emit()
+
+ def emitspec(self, ident):
+ emit("*")
+ self.to.emitspec(ident)
+
+class Array(Type):
+ def __init__(self, base: Type, len: int):
+ self.base = base
+ self.len = len
+
+ def emit(self):
+ self.base.emit()
+
+ def emitspec(self, ident):
+ self.base.emitspec(ident)
+ emit(f"[{self.len}]")
+
+# TODO: Typedefs...
+
+# ------------------------------------------
+# C expressions
+class Expr(Emitter):
+ def emit():
+ pass
+
+# Binary operators
+class BinOp(Expr):
+ def __init__(self, left: Expr, right: Expr):
+ self.l = left
+ self.r = right
+
+ def emit(self):
+ pass
+
+# TODO: check types if they are vectorized and emit correct intrinsic
+class Add(BinOp):
+ def emit(self):
+ self.l.emit()
+ emit(f" + ")
+ self.r.emit()
+
+class Sub(BinOp):
+ def emit(self):
+ self.l.emit()
+ emit(f" - ")
+ self.r.emit()
+
+class Mul(BinOp):
+ def emit(self):
+ self.l.emit()
+ emit(f" * ")
+ self.r.emit()
+
+class Div(BinOp):
+ def emit(self):
+ self.l.emit()
+ emit(f" / ")
+ self.r.emit()
+
+class Gt(BinOp):
+ def emit(self):
+ self.l.emit()
+ emit(f" > ")
+ self.r.emit()
+
+class Lt(BinOp):
+ def emit(self):
+ self.l.emit()
+ emit(f" < ")
+ self.r.emit()
+
+class Ge(BinOp):
+ def emit(self):
+ self.l.emit()
+ emit(f" >= ")
+ self.r.emit()
+
+class Le(BinOp):
+ def emit(self):
+ self.l.emit()
+ emit(f" <= ")
+ self.r.emit()
+
+class Eq(BinOp):
+ def emit(self):
+ self.l.emit()
+ emit(f" == ")
+ self.r.emit()
+
+# Assignment (stores)
+class Assign(Expr):
+ def __init__(self, lhs: Expr, rhs: Expr):
+ self.lhs = lhs
+ self.rhs = rhs
+
+class Mv(Assign):
+ def emit(self):
+ self.lhs.emit()
+ emit(f" = ")
+ self.rhs.emit()
+
+class AddMv(Assign):
+ def emit(self):
+ self.lhs.emit()
+ emit(f" += ")
+ self.rhs.emit()
+
+class SubMv(Assign):
+ def emit(self):
+ self.lhs.emit()
+ emit(f" -= ")
+ self.rhs.emit()
+
+class MulMv(Assign):
+ def emit(self):
+ self.lhs.emit()
+ emit(f" *= ")
+ self.rhs.emit()
+
+class DivMv(Assign):
+ def emit(self):
+ self.lhs.emit()
+ emit(f" /= ")
+ self.rhs.emit()
+
+class Comma(Expr):
+ def __init__(self, x: Expr, next: Expr):
+ self.expr = (x, next)
+
+ def emit(self):
+ self.expr[0].emit()
+ emit(", ")
+ self.expr[1].emit()
+
+# Common assignments are kept globally
+# C statements
+
+class Stmt(Emitter):
+ def emit(self):
+ emit(";")
+
+class Empty(Stmt):
+ def __init__(self):
+ pass
+ def emit(self):
+ super(Empty, self).emit()
+
+class Block(Stmt):
+ def __init__(self, stmts: List[Stmt]):
+ self.stmts = stmts
+
+ def emit(self):
+ enter_scope()
+ for stmt in self.stmts:
+ stmt.emit()
+ exits_scope()
+ super(Block, self).emit()
+
+class For(Stmt):
+ def __init__(self, init: Expr, cond: Expr, step: Expr, body: Stmt):
+ self.init = init
+ self.cond = cond
+ self.step = step
+ self.body = body
+
+ def emit(self):
+ emit("for (")
+ self.init.emit()
+ emit(";")
+ self.cond.emit()
+ emit(";")
+ self.step.emit()
+ emit(")")
+
+ self.body.emit()
+ super(For, self).emit()
+
+class Return(Stmt):
+ def __init__(self, val: Expr):
+ self.val = val
+
+ def emit(self):
+ emitln()
+ emit("return ")
+ self.val.emit()
+ super(Return, self).emit()
+
+class StmtExpr(Stmt):
+ def __init__(self, x: Expr):
+ self.x = x
+
+ def emit(self):
+ self.x.emit()
+ super(StmtExpr, self).emit()
+
+class Mem(Enum):
+ Auto = ""
+ Static = "static"
+ Register = "register"
+ Typedef = "typedef"
+ External = "extern"
+
+class Decl(Emitter):
+ def __init__(self):
+ pass
+
+ def emit(self):
+ pass
+
+class Func(Decl):
+ def __init__(self, ident: str, ret: Expr = Void, params: List[Param] = [], vars: List[Var | List[Var]] = [], body: List[Stmt] = []):
+ self.ident = ident
+ self.ret = ret
+ self.params = params
+ self.vars = vars
+ self.stmts = body
+
+ def emit(self):
+ self.ret.emit()
+ emitln()
+ emit(self.ident)
+ emit("(")
+ for i, p in enumerate(self.params):
+ p.emittype()
+ emit(" ")
+ p.emitspec()
+ if i < len(self.params) - 1:
+ emit(", ")
+ emit(")\n")
+
+ enter_scope()
+
+ for var in self.vars:
+ if isinstance(var, list):
+ v = var[0]
+ v.emittype()
+ emit(" ")
+ v.emitspec()
+ for v in var[1:]:
+ emit(", ")
+ v.emitspec()
+ else:
+ var.emittype()
+ emit(" ")
+ var.emitspec()
+
+ emit(";")
+ emitln()
+
+ emitln()
+ for stmt in self.stmts:
+ stmt.emit()
+
+ exits_scope()
+
+ def declare(self, var: Var, *vars: List[Var]):
+ self.vars.append(var)
+ self.vars.extend(vars)
+
+ def instruct(self, stmt: Stmt | Expr, *args: List[Stmt | Expr]):
+ def push(n):
+ if isinstance(n, Stmt):
+ self.stmts.append(n)
+ elif isinstance(n, Expr):
+ self.stmts.append(StmtExpr(n))
+ else:
+ raise TypeError("unrecognized type for function")
+
+ push(stmt)
+ for arg in args:
+ push(arg)
+
+class Var(Decl):
+ def __init__(self, type: Type, name: str, storage: Mem = Mem.Auto):
+ self.name = name
+ self.type = type
+ self.storage = storage
+
+ def emit(self):
+ emit(f"{self.name}")
+
+ def emittype(self):
+ if self.storage != Mem.Auto:
+ emit(self.storage.value)
+ emit(" ")
+ self.type.emit()
+
+ def emitspec(self):
+ self.type.emitspec(self.name)
+
+class Param(Var):
+ def __init__(self, type: Type, name: str):
+ return super(Param, self).__init__(type, name, mem.Auto)
+
+# ------------------------------------------------------------------------
+# AST modification functions