aboutsummaryrefslogtreecommitdiff
path: root/src/libmath/matrix.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/libmath/matrix.c')
-rw-r--r--src/libmath/matrix.c176
1 files changed, 176 insertions, 0 deletions
diff --git a/src/libmath/matrix.c b/src/libmath/matrix.c
new file mode 100644
index 0000000..e8bca0b
--- /dev/null
+++ b/src/libmath/matrix.c
@@ -0,0 +1,176 @@
+#include <u.h>
+#include <libn.h>
+#include <libmath.h>
+
+/* TODO: replace (incrementally) with native C version! */
+#include <vendor/blas/cblas.h>
+#include <vendor/blas/lapacke.h>
+
+// -----------------------------------------------------------------------
+// level 1
+
+error
+la·vecslice(math·Vector *x, int min, int max, int inc)
+{
+ if (max > x->len || min < 0) {
+ errorf("out of bounds: attempted to access vector past length");
+ return 1;
+ }
+ x->len = (max - min) / inc;
+ x->d += x->inc * min;
+ x->inc *= inc;
+
+ return 0;
+}
+
+/* simple blas wrappers */
+void
+la·veccopy(math·Vector *dst, math·Vector *src)
+{
+ return cblas_dcopy(src->len, src->d, src->inc, dst->d, dst->inc);
+}
+
+double
+la·vecnorm(math·Vector *x)
+{
+ return cblas_dnrm2(x->len, x->d, x->inc);
+}
+
+void
+la·vecscale(math·Vector *x, double a)
+{
+ return cblas_dscal(x->len, a, x->d, x->inc);
+}
+
+double
+la·vecdot(math·Vector *x, math·Vector *y)
+{
+ return cblas_ddot(x->len, x->d, x->inc, y->d, y->inc);
+}
+
+// -----------------------------------------------------------------------
+// level 2
+
+error
+la·vecmat(math·Vector *x, math·Matrix *M)
+{
+ if (M->dim[1] != x->len) {
+ errorf("incompatible matrix dimensions");
+ return 1;
+ }
+ if (M->state & ~mat·trans)
+ cblas_dgemv(CblasRowMajor,CblasNoTrans,M->dim[0],M->dim[1],1.,M->d,M->inc,x->d,x->inc,0.,x->d,x->inc);
+ else
+ cblas_dgemv(CblasRowMajor,CblasTrans,M->dim[0],M->dim[1],1.,M->d,M->inc,x->d,x->inc,0.,x->d,x->inc);
+
+ return 0;
+}
+
+// -----------------------------------------------------------------------
+// level 3
+
+void
+la·transpose(math·Matrix *X)
+{
+ int tmp;
+ X->state ^= mat·trans;
+ tmp = X->dim[0], X->dim[0] = X->dim[1], X->dim[1] = tmp;
+}
+
+error
+la·matrow(math·Matrix *X, int r, math·Vector *row)
+{
+ if (r < 0 || r >= X->dim[0]) {
+ errorf("out of bounds");
+ return 1;
+ }
+
+ row->len = X->dim[1];
+ row->inc = 1;
+ row->d = X->d + X->dim[1] * r;
+
+ return 0;
+}
+
+error
+la·matcol(math·Matrix *X, int c, math·Vector *col)
+{
+ if (c < 0 || c >= X->dim[1]) {
+ errorf("out of bounds");
+ return 1;
+ }
+
+ col->len = X->dim[0];
+ col->inc = X->dim[1];
+ col->d = X->d + c;
+
+ return 0;
+}
+
+error
+la·matslice(math·Matrix *X, int r[3], int c[3])
+{
+ /* TODO */
+ return 0;
+}
+
+error
+la·eig(math·Matrix *X)
+{
+
+}
+
+/* X = A*B */
+error
+la·matmul(math·Matrix *X, math·Matrix *A, math·Matrix *B)
+{
+ if (A->dim[1] != B->dim[0]) {
+ errorf("number of interior dimensions of A '%d' not equal to that of B '%d'", A->dim[1], B->dim[0]);
+ return 1;
+ }
+ if (X->dim[0] != A->dim[0]) {
+ errorf("number of exterior dimensions of X '%d' not equal to that of A '%d'", X->dim[0], A->dim[0]);
+ return 1;
+ }
+ if (X->dim[1] != B->dim[1]) {
+ errorf("number of exterior dimensions of X '%d' not equal to that of B '%d'", X->dim[1], B->dim[1]);
+ return 1;
+ }
+
+ if (X->state & ~mat·trans)
+ if (A->state & ~mat·trans)
+ cblas_dgemm(CblasRowMajor,CblasNoTrans,CblasNoTrans,A->dim[0],B->dim[1],A->dim[1],1.,A->d,A->inc,B->d,B->inc,0.,X->d,X->inc);
+ else
+ cblas_dgemm(CblasRowMajor,CblasNoTrans,CblasTrans,A->dim[0],B->dim[1],A->dim[1],1.,A->d,A->inc,B->d,B->inc,0.,X->d,X->inc);
+ else
+ if (A->state & ~mat·trans)
+ cblas_dgemm(CblasRowMajor,CblasTrans,CblasNoTrans,A->dim[0],B->dim[1],A->dim[1],1.,A->d,A->inc,B->d,B->inc,0.,X->d,X->inc);
+ else
+ cblas_dgemm(CblasRowMajor,CblasTrans,CblasTrans,A->dim[0],B->dim[1],A->dim[1],1.,A->d,A->inc,B->d,B->inc,0.,X->d,X->inc);
+
+ return 0;
+}
+
+/*
+ * solves A*X=B
+ * pass in B via X
+ */
+error
+la·solve(math·Matrix *X, math·Matrix *A)
+{
+ error err;
+ int n, *ipv;
+ static int buf[512];
+ if (n = A->dim[0], n < arrlen(buf)) {
+ ipv = buf;
+ n = 0;
+ } else
+ ipv = malloc(n*sizeof(*ipv));
+
+ /* TODO: utilize more specific regimes if applicable */
+ err = LAPACKE_dgesv(LAPACK_ROW_MAJOR,A->dim[0],X->dim[1],A->d,A->inc,ipv,X->d,X->inc);
+
+ if (n)
+ free(ipv);
+ return err;
+}