aboutsummaryrefslogtreecommitdiff
path: root/include
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-08 16:45:52 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-08 16:45:52 -0700
commit327ca20a2a89d2408b53ff7854982560304cb76c (patch)
tree4fc3231b96b65e6f15d3852e3b6e4f3109b1f0e7 /include
parentd5e3041d34e4615ea8f81bd39a2a9231ef38253f (diff)
added more level 2 and 3 functions to blas implementation
Diffstat (limited to 'include')
-rw-r--r--include/libmath.h65
1 files changed, 64 insertions, 1 deletions
diff --git a/include/libmath.h b/include/libmath.h
index ec15f6f..b0ae434 100644
--- a/include/libmath.h
+++ b/include/libmath.h
@@ -130,4 +130,67 @@ double math·trunc(double);
float math·truncf(float);
// -----------------------------------------------------------------------
-// linear algebra
+// basic linear algebra compute kernels
+
+// TODO: think of better names
+enum
+{
+ blas·LowerTri = 1u,
+ blas·Transpose = 2u,
+ blas·ConjTranspose = 4u,
+ blas·DiagOnes = 8u,
+ blas·LeftSide = 16u,
+};
+
+typedef uint32 blas·Flag;
+
+/* level 1 */
+void blas·rot(int len, double *x, double *y, double cos, double sin);
+void blas·rotg(double *a, double *b, double *cos, double *sin);
+error blas·rotm(int len, double *x, double *y, double p[5]);
+void blas·scale(int len, double *x, double a);
+void blas·copy(int len, double *x, double *y);
+void blas·swap(int len, double *x, double *y);
+void blas·axpy(int len, double a, double *x, double *y);
+double blas·dot(int len, double *x, double *y);
+int blas·argmax(int len, double *x);
+int blas·argmin(int len, double *x);
+
+/* level 2 */
+void blas·tpmv(blas·Flag f, int n, double *m, double *x);
+void blas·tpsv(blas·Flag f, int n, double *m, double *x);
+error blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y) ;
+void blas·ger(int nrow, int ncol, double a, double *x, double *y, double *m);
+void blas·her(int n, double a, double *x, double *m);
+void blas·syr(int nrow, int ncol, double a, double *x, double *m);
+
+/* level 3 */
+void blas·gemm(int n1, int n2, int n3, double a, double *m1, double *m2, double b, double *m3);
+void blas·trmm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2);
+void blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2);
+
+// -----------------------------------------------------------------------
+// higher level linear algebra
+
+struct linalg·Header
+{
+ void *h;
+ mem·Allocator heap;
+
+ double *data;
+};
+
+typedef struct math·Vector
+{
+ int len;
+ struct linalg·Header;
+} math·Vector;
+
+typedef struct math·Matrix
+{
+ int dim[2];
+ blas·Flag kind;
+ struct linalg·Header;
+} math·Matrix;
+
+// TODO: tensor ala numpy