diff options
author | Nicholas Noll <nbnoll@eml.cc> | 2020-05-08 16:45:52 -0700 |
---|---|---|
committer | Nicholas Noll <nbnoll@eml.cc> | 2020-05-08 16:45:52 -0700 |
commit | 327ca20a2a89d2408b53ff7854982560304cb76c (patch) | |
tree | 4fc3231b96b65e6f15d3852e3b6e4f3109b1f0e7 /include/libmath.h | |
parent | d5e3041d34e4615ea8f81bd39a2a9231ef38253f (diff) |
added more level 2 and 3 functions to blas implementation
Diffstat (limited to 'include/libmath.h')
-rw-r--r-- | include/libmath.h | 65 |
1 files changed, 64 insertions, 1 deletions
diff --git a/include/libmath.h b/include/libmath.h index ec15f6f..b0ae434 100644 --- a/include/libmath.h +++ b/include/libmath.h @@ -130,4 +130,67 @@ double math·trunc(double); float math·truncf(float); // ----------------------------------------------------------------------- -// linear algebra +// basic linear algebra compute kernels + +// TODO: think of better names +enum +{ + blas·LowerTri = 1u, + blas·Transpose = 2u, + blas·ConjTranspose = 4u, + blas·DiagOnes = 8u, + blas·LeftSide = 16u, +}; + +typedef uint32 blas·Flag; + +/* level 1 */ +void blas·rot(int len, double *x, double *y, double cos, double sin); +void blas·rotg(double *a, double *b, double *cos, double *sin); +error blas·rotm(int len, double *x, double *y, double p[5]); +void blas·scale(int len, double *x, double a); +void blas·copy(int len, double *x, double *y); +void blas·swap(int len, double *x, double *y); +void blas·axpy(int len, double a, double *x, double *y); +double blas·dot(int len, double *x, double *y); +int blas·argmax(int len, double *x); +int blas·argmin(int len, double *x); + +/* level 2 */ +void blas·tpmv(blas·Flag f, int n, double *m, double *x); +void blas·tpsv(blas·Flag f, int n, double *m, double *x); +error blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y) ; +void blas·ger(int nrow, int ncol, double a, double *x, double *y, double *m); +void blas·her(int n, double a, double *x, double *m); +void blas·syr(int nrow, int ncol, double a, double *x, double *m); + +/* level 3 */ +void blas·gemm(int n1, int n2, int n3, double a, double *m1, double *m2, double b, double *m3); +void blas·trmm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2); +void blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2); + +// ----------------------------------------------------------------------- +// higher level linear algebra + +struct linalg·Header +{ + void *h; + mem·Allocator heap; + + double *data; +}; + +typedef struct math·Vector +{ + int len; + struct linalg·Header; +} math·Vector; + +typedef struct math·Matrix +{ + int dim[2]; + blas·Flag kind; + struct linalg·Header; +} math·Matrix; + +// TODO: tensor ala numpy |