From 327ca20a2a89d2408b53ff7854982560304cb76c Mon Sep 17 00:00:00 2001
From: Nicholas Noll
Date: Fri, 8 May 2020 16:45:52 -0700
Subject: added more level 2 and 3 functions to blas implementation

---
 include/libmath.h | 65 ++++++++++++++++++++++++++++++++++++++++++++++++++++++-
 1 file changed, 64 insertions(+), 1 deletion(-)

(limited to 'include')

diff --git a/include/libmath.h b/include/libmath.h
index ec15f6f..b0ae434 100644
--- a/include/libmath.h
+++ b/include/libmath.h
@@ -130,4 +130,67 @@ double math·trunc(double);
 float math·truncf(float);
 
 // -----------------------------------------------------------------------
-// linear algebra
+// basic linear algebra compute kernels
+
+// TODO: think of better names
+enum // option bits: every value is a distinct power of two; presumably OR-ed together into a blas·Flag -- confirm against the kernels
+{
+    blas·LowerTri = 1u,
+    blas·Transpose = 2u,
+    blas·ConjTranspose = 4u,
+    blas·DiagOnes = 8u,
+    blas·LeftSide = 16u,
+};
+
+typedef uint32 blas·Flag; // type of the `f` argument of tpmv/tpsv/trmm/trsm and of math·Matrix.kind
+
+/* level 1: routines over vectors x, y of length len; names follow the reference BLAS, semantics presumably match -- confirm in the implementation */
+void blas·rot(int len, double *x, double *y, double cos, double sin);
+void blas·rotg(double *a, double *b, double *cos, double *sin);
+error blas·rotm(int len, double *x, double *y, double p[5]); // one of the two kernels here that report an `error`
+void blas·scale(int len, double *x, double a);
+void blas·copy(int len, double *x, double *y);
+void blas·swap(int len, double *x, double *y);
+void blas·axpy(int len, double a, double *x, double *y);
+double blas·dot(int len, double *x, double *y);
+int blas·argmax(int len, double *x);
+int blas·argmin(int len, double *x);
+
+/* level 2: routines taking one matrix m together with vector(s); matrix storage layout is not specified in this header */
+void blas·tpmv(blas·Flag f, int n, double *m, double *x);
+void blas·tpsv(blas·Flag f, int n, double *m, double *x);
+error blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y) ; // returns an `error`, unlike the other level 2 kernels
+void blas·ger(int nrow, int ncol, double a, double *x, double *y, double *m);
+void blas·her(int n, double a, double *x, double *m);
+void blas·syr(int nrow, int ncol, double a, double *x, double *m); // NOTE(review): takes nrow/ncol while blas·her takes a single n -- confirm this asymmetry is intended
+
+/* level 3: routines taking several matrices (m1, m2, m3); the tr* variants are configured through a blas·Flag */
+void blas·gemm(int n1, int n2, int n3, double a, double *m1, double *m2, double b, double *m3);
+void blas·trmm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2);
+void blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2);
+
+// -----------------------------------------------------------------------
+// higher level linear algebra
+
+struct linalg·Header // storage fields declared as a member of both math·Vector and math·Matrix
+{
+    void *h; // NOTE(review): presumably the context handle that goes with `heap` -- confirm
+    mem·Allocator heap; // allocator type declared elsewhere in the project
+
+    double *data; // element storage; ownership and layout are not specified in this header
+};
+
+typedef struct math·Vector
+{
+    int len; // presumably the number of elements in data -- confirm
+    struct linalg·Header; // NOTE(review): tagged unnamed member; only embeds h/heap/data under -fplan9-extensions or -fms-extensions -- confirm build flags
+} math·Vector;
+
+typedef struct math·Matrix
+{
+    int dim[2]; // two extents; presumably {rows, cols} -- confirm the order
+    blas·Flag kind; // presumably a combination of the blas· option bits -- confirm
+    struct linalg·Header; // NOTE(review): tagged unnamed member; only embeds h/heap/data under -fplan9-extensions or -fms-extensions -- confirm build flags
+} math·Matrix;
+
+// TODO: tensor type a la numpy
-- 
cgit v1.2.1