aboutsummaryrefslogtreecommitdiff
path: root/sys
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-08 16:45:52 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-08 16:45:52 -0700
commit327ca20a2a89d2408b53ff7854982560304cb76c (patch)
tree4fc3231b96b65e6f15d3852e3b6e4f3109b1f0e7 /sys
parentd5e3041d34e4615ea8f81bd39a2a9231ef38253f (diff)
added more level 2 and 3 functions to blas implementation
Diffstat (limited to 'sys')
-rw-r--r--sys/libmath/blas.c406
-rw-r--r--sys/libmath/linalg.c5
2 files changed, 365 insertions, 46 deletions
diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c
index f12e3e2..227715c 100644
--- a/sys/libmath/blas.c
+++ b/sys/libmath/blas.c
@@ -276,6 +276,7 @@ blas·copy(int len, double *x, double *y)
* y <=> x
*/
+static
void
swap_kernel8_avx2(int n, double *x, double *y)
{
@@ -294,6 +295,7 @@ swap_kernel8_avx2(int n, double *x, double *y)
}
}
+static
void
swap_kernel8(int n, double *x, double *y)
{
@@ -462,6 +464,8 @@ blas·dot(int len, double *x, double *y)
int i, n;
double res;
+ if (len == 0) return 0;
+
n = len & ~15; // neat trick
res = dot_kernel8_fma3(n, x, y);
@@ -483,7 +487,7 @@ argmax_kernel8_avx2(int n, double *x)
register int i, msk, maxidx[4];
__m256d val, cmp, max;
- maxidx[0] = 0, maxidx[1] = 1, maxidx[2] = 2, maxidx[3] = 3;
+ maxidx[0] = 0, maxidx[1] = 0, maxidx[2] = 0, maxidx[3] = 0;
max = _mm256_loadu_pd(x);
for (i = 0; i < n; i += 4) {
@@ -493,102 +497,102 @@ argmax_kernel8_avx2(int n, double *x)
switch (msk) {
case 1:
max = _mm256_blend_pd(max, val, 1);
- maxidx[0] = i+0;
+ maxidx[0] = i;
break;
case 2:
max = _mm256_blend_pd(max, val, 2);
- maxidx[1] = i+1;
+ maxidx[1] = i;
break;
case 3:
max = _mm256_blend_pd(max, val, 3);
- maxidx[0] = i+0;
- maxidx[1] = i+1;
+ maxidx[0] = i;
+ maxidx[1] = i;
break;
case 4:
max = _mm256_blend_pd(max, val, 4);
- maxidx[2] = i+2;
+ maxidx[2] = i;
break;
case 5:
max = _mm256_blend_pd(max, val, 5);
- maxidx[2] = i+2;
- maxidx[0] = i+0;
+ maxidx[2] = i;
+ maxidx[0] = i;
break;
case 6:
max = _mm256_blend_pd(max, val, 6);
- maxidx[2] = i+2;
- maxidx[1] = i+1;
+ maxidx[2] = i;
+ maxidx[1] = i;
break;
case 7:
max = _mm256_blend_pd(max, val, 7);
- maxidx[2] = i+2;
- maxidx[1] = i+1;
- maxidx[0] = i+0;
+ maxidx[2] = i;
+ maxidx[1] = i;
+ maxidx[0] = i;
break;
case 8:
max = _mm256_blend_pd(max, val, 8);
- maxidx[3] = i+3;
+ maxidx[3] = i;
break;
case 9:
max = _mm256_blend_pd(max, val, 9);
- maxidx[3] = i+3;
- maxidx[0] = i+0;
+ maxidx[3] = i;
+ maxidx[0] = i;
break;
case 10:
max = _mm256_blend_pd(max, val, 10);
- maxidx[3] = i+3;
- maxidx[1] = i+1;
+ maxidx[3] = i;
+ maxidx[1] = i;
break;
case 11:
max = _mm256_blend_pd(max, val, 11);
- maxidx[3] = i+3;
- maxidx[1] = i+1;
- maxidx[0] = i+0;
+ maxidx[3] = i;
+ maxidx[1] = i;
+ maxidx[0] = i;
break;
case 12:
max = _mm256_blend_pd(max, val, 12);
- maxidx[3] = i+3;
- maxidx[2] = i+2;
+ maxidx[3] = i;
+ maxidx[2] = i;
break;
case 13:
max = _mm256_blend_pd(max, val, 13);
- maxidx[3] = i+3;
- maxidx[2] = i+2;
- maxidx[0] = i+0;
+ maxidx[3] = i;
+ maxidx[2] = i;
+ maxidx[0] = i;
break;
case 14:
max = _mm256_blend_pd(max, val, 14);
- maxidx[3] = i+3;
- maxidx[2] = i+2;
- maxidx[1] = i+1;
+ maxidx[3] = i;
+ maxidx[2] = i;
+ maxidx[1] = i;
break;
case 15:
max = _mm256_blend_pd(max, val, 15);
- maxidx[3] = i+3;
- maxidx[2] = i+2;
- maxidx[1] = i+1;
- maxidx[0] = i+0;
+ maxidx[3] = i;
+ maxidx[2] = i;
+ maxidx[1] = i;
+ maxidx[0] = i;
break;
case 0:
default: ;
}
}
- maxidx[0] = (x[maxidx[0]] > x[maxidx[1]]) ? maxidx[0] : maxidx[1];
- maxidx[1] = (x[maxidx[2]] > x[maxidx[3]]) ? maxidx[2] : maxidx[3];
+ maxidx[0] = (x[maxidx[0]+0] > x[maxidx[1]+1]) ? maxidx[0]+0 : maxidx[1]+1;
+ maxidx[1] = (x[maxidx[2]+2] > x[maxidx[3]+3]) ? maxidx[2]+2 : maxidx[3]+3;
maxidx[0] = (x[maxidx[0]] > x[maxidx[1]]) ? maxidx[0] : maxidx[1];
return maxidx[0];
@@ -597,7 +601,7 @@ argmax_kernel8_avx2(int n, double *x)
int
argmax_kernel8(int n, double *x)
{
-#define SET(d) idx[d] = 0, max[d] = x[d]
+#define SET(d) idx[d] = d, max[d] = x[d]
#define PUT(d) if (x[i+d] > max[d]) idx[d] = i+d, max[d] = x[i+d]
int i, idx[8];
double max[8];
@@ -641,6 +645,171 @@ blas·argmax(int len, double *x)
return n;
}
+/*
+ * AVX kernel: index of the smallest value in x[0..n-1].
+ *
+ * Four running minima are kept, one per vector lane; lane k sees the
+ * elements x[k], x[k+4], x[k+8], ...  minidx[k] stores the base offset i
+ * of the 4-wide block that supplied lane k's current minimum; the lane
+ * offset k itself is only added back in the final reduction.
+ *
+ * Preconditions visible from the code:
+ *  - x[0..3] is always loaded (initial load and final reduction), even
+ *    when n == 0, so x must hold at least four doubles;
+ *  - every iteration loads four doubles starting at x+i, so n should be
+ *    a multiple of 4.
+ * Comparisons are strict and ordered (_CMP_LT_OS), so a NaN never
+ * replaces the current minimum.  Ties between different lanes are not
+ * guaranteed to resolve to the lowest index.
+ */
+int
+argmin_kernel8_avx2(int n, double *x)
+{
+    register int i, msk;
+    /* deliberately not `register`: subscripting a register array is
+     * undefined behaviour in ISO C */
+    int minidx[4];
+    __m256d val, cmp, min;
+
+    minidx[0] = 0, minidx[1] = 0, minidx[2] = 0, minidx[3] = 0;
+    min = _mm256_loadu_pd(x);
+
+    for (i = 0; i < n; i += 4) {
+        val = _mm256_loadu_pd(x+i);
+        cmp = _mm256_cmp_pd(val, min, _CMP_LT_OS);
+        /* bit k of msk is set when lane k of val beats lane k of min */
+        msk = _mm256_movemask_pd(cmp);
+        /* the blend needs a compile-time immediate, hence one case per mask */
+        switch (msk) {
+        case 1:
+            min = _mm256_blend_pd(min, val, 1);
+            minidx[0] = i;
+            break;
+
+        case 2:
+            min = _mm256_blend_pd(min, val, 2);
+            minidx[1] = i;
+            break;
+
+        case 3:
+            min = _mm256_blend_pd(min, val, 3);
+            minidx[0] = i;
+            minidx[1] = i;
+            break;
+
+        case 4:
+            min = _mm256_blend_pd(min, val, 4);
+            minidx[2] = i;
+            break;
+
+        case 5:
+            min = _mm256_blend_pd(min, val, 5);
+            minidx[2] = i;
+            minidx[0] = i;
+            break;
+
+        case 6:
+            min = _mm256_blend_pd(min, val, 6);
+            minidx[2] = i;
+            minidx[1] = i;
+            break;
+
+        case 7:
+            min = _mm256_blend_pd(min, val, 7);
+            minidx[2] = i;
+            minidx[1] = i;
+            minidx[0] = i;
+            break;
+
+        case 8:
+            min = _mm256_blend_pd(min, val, 8);
+            minidx[3] = i;
+            break;
+
+        case 9:
+            min = _mm256_blend_pd(min, val, 9);
+            minidx[3] = i;
+            minidx[0] = i;
+            break;
+
+        case 10:
+            min = _mm256_blend_pd(min, val, 10);
+            minidx[3] = i;
+            minidx[1] = i;
+            break;
+
+        case 11:
+            min = _mm256_blend_pd(min, val, 11);
+            minidx[3] = i;
+            minidx[1] = i;
+            minidx[0] = i;
+            break;
+
+        case 12:
+            min = _mm256_blend_pd(min, val, 12);
+            minidx[3] = i;
+            minidx[2] = i;
+            break;
+
+        case 13:
+            min = _mm256_blend_pd(min, val, 13);
+            minidx[3] = i;
+            minidx[2] = i;
+            minidx[0] = i;
+            break;
+
+        case 14:
+            min = _mm256_blend_pd(min, val, 14);
+            minidx[3] = i;
+            minidx[2] = i;
+            minidx[1] = i;
+            break;
+
+        case 15:
+            min = _mm256_blend_pd(min, val, 15);
+            minidx[3] = i;
+            minidx[2] = i;
+            minidx[1] = i;
+            minidx[0] = i;
+            break;
+
+        case 0:
+        default: ;
+        }
+    }
+    /* pairwise reduction across lanes; "+k" turns a block base into the
+     * absolute index of lane k's element */
+    minidx[0] = (x[minidx[0]+0] < x[minidx[1]+1]) ? minidx[0]+0 : minidx[1]+1;
+    minidx[1] = (x[minidx[2]+2] < x[minidx[3]+3]) ? minidx[2]+2 : minidx[3]+3;
+    minidx[0] = (x[minidx[0]] < x[minidx[1]]) ? minidx[0] : minidx[1];
+
+    return minidx[0];
+}
+
+
+/*
+ * Scalar argmin kernel, eight interleaved lanes.
+ *
+ * Lane d follows x[d], x[d+8], x[d+16], ... and remembers both the
+ * smallest value seen and where it was seen.  The lanes are seeded from
+ * x[0..7] regardless of n, so x must hold at least eight doubles; the
+ * main loop consumes eight elements per step.  All comparisons are
+ * strict, so within a lane the earliest occurrence of the minimum is
+ * kept, and across lanes the lowest-numbered lane wins a tie.
+ */
+int
+argmin_kernel8(int n, double *x)
+{
+    enum { LANES = 8 };
+    int base, lane, best;
+    int where[LANES];
+    double low[LANES];
+
+    /* seed every lane with the first block */
+    for (lane = 0; lane < LANES; lane++) {
+        where[lane] = lane;
+        low[lane]   = x[lane];
+    }
+
+    /* sweep the input one block of eight at a time */
+    for (base = 0; base < n; base += LANES) {
+        for (lane = 0; lane < LANES; lane++) {
+            if (x[base + lane] < low[lane]) {
+                where[lane] = base + lane;
+                low[lane]   = x[base + lane];
+            }
+        }
+    }
+
+    /* pick the lane holding the overall minimum */
+    best = 0;
+    for (lane = 1; lane < LANES; lane++) {
+        if (low[lane] < low[best]) {
+            best = lane;
+        }
+    }
+    return where[best];
+}
+
+/*
+ * Index of the smallest element of x[0..len-1].
+ *
+ * The leading multiple-of-8 prefix is handled by the AVX kernel; the
+ * remaining (at most 7) elements are scanned one by one.  The kernel
+ * unconditionally reads x[0..3], so it is only invoked when at least one
+ * full block of eight exists -- otherwise short vectors (len < 4) would
+ * be read out of bounds.
+ *
+ * Returns 0 when len <= 0 (no element is read in that case).
+ */
+int
+blas·argmin(int len, double *x)
+{
+    int i, n;
+    double min;
+
+    if (len <= 0) {
+        return 0;
+    }
+
+    i = len & ~7;    /* largest multiple of 8 not exceeding len */
+    n = (i > 0) ? argmin_kernel8_avx2(i, x) : 0;
+
+    /* scalar tail; strict "<" keeps the earlier index on ties */
+    min = x[n];
+    for (; i < len; i++) {
+        if (x[i] < min) {
+            n = i;
+            min = x[i];
+        }
+    }
+
+    return n;
+}
+
// -----------------------------------------------------------------------
// level two
@@ -650,8 +819,47 @@ blas·argmax(int len, double *x)
*/
+// NOTE: All triangular matrix methods are assumed packed and upper for now!
+
/*
- * Affine transformation
+ * triangular shaped transformation
+ * x = Mx
+ * @M: square triangular
+ * TODO(PERF): Diagnose speed issues
+ * TODO: Finish all other flag cases!
+ */
+/*
+ * In-place product x = M x for a packed triangular matrix.
+ *
+ * The packed buffer is walked row by row: row r contributes n-r
+ * consecutive entries, which are dotted with x[r..n-1].  x[r] is
+ * overwritten only after it has been read, and later rows never read
+ * it again, so no scratch vector is needed.
+ * The flag argument f is accepted but not consulted.
+ */
+void
+blas·tpmv(blas·Flag f, int n, double *m, double *x)
+{
+    int row, width;
+    double *packed;
+
+    packed = m;
+    for (row = 0; row < n; row++) {
+        width  = n - row;    /* entries stored for this row */
+        x[row] = blas·dot(width, packed, x + row);
+        packed += width;     /* advance to the next packed row */
+    }
+}
+
+/*
+ * solve triangular set of equations
+ * x = M^{-1}b
+ * @M: square triangular
+ * TODO(PERF): Diagnose speed issues
+ * TODO: Finish all other flag cases!
+ */
+/*
+ * In-place triangular solve x = M^{-1} x by back substitution.
+ *
+ * M is read from a packed buffer in which row i owns n-i consecutive
+ * entries beginning with its diagonal element.  Rows are processed from
+ * the last to the first: the already solved entries x[i+1..n-1] are
+ * dotted with the off-diagonal part of row i, subtracted from x[i], and
+ * the result is divided by the diagonal.
+ *
+ * The walk starts one past the end of the packed buffer and steps back
+ * to each row's diagonal, so no pointer is ever formed before the start
+ * of m or x.  Nothing is done for n <= 0.
+ * The flag argument f is accepted but not consulted.  A zero diagonal
+ * entry is not checked for.
+ */
+void
+blas·tpsv(blas·Flag f, int n, double *m, double *x)
+{
+    int i;
+    double r, *diag;
+
+    if (n <= 0) {
+        return;
+    }
+
+    /* widened so that n*(n+1) cannot overflow int for large n */
+    diag = m + ((long long)n * (n + 1)) / 2;
+    for (i = n-1; i >= 0; --i) {
+        diag -= (n - i);    /* now at the diagonal entry of row i */
+        r = blas·dot(n-i-1, diag+1, x+i+1);
+        x[i] = (x[i] - r) / *diag;
+    }
+}
+
+/*
+ * general affine transformation
* y = aMx + by
*/
static
@@ -777,6 +985,17 @@ blas·ger(int nrow, int ncol, double a, double *x, double *y, double *m)
}
/*
+ * rank one addition
+ * M = ax(x^T) + M
+ */
+
+void
+blas·her(int n, double a, double *x, double *m)
+{ /* Rank-one update built from a single vector: forwards to blas·ger
+   * with both dimensions set to n and x passed as both vector operands.
+   * NOTE(review): presumably this yields M = a*x*x^T + M over the full
+   * n-by-n matrix -- confirm against blas·ger, which is not shown here,
+   * including whether it tolerates its two vector arguments aliasing. */
+ blas·ger(n, n, a, x, x, m);
+}
+
+/*
* symmetric rank one addition
* M = ax(x^T) + M
* TODO: vectorize kernel
@@ -794,6 +1013,7 @@ blas·syr(int nrow, int ncol, double a, double *x, double *m)
}
}
+
// -----------------------------------------------------------------------
// level three
@@ -836,12 +1056,53 @@ blas·gemm(int n1, int n2, int n3, double a, double *m1, double *m2, double b, d
}
}
-#define NITER 30000
-#define NCOL 50007
-#define NROW 107
+/*
+ * triangular matrix multiplication
+ * m2 = a * m1 * m2 _OR_ a * m2 * m1
+ * m1 is assumed triangular
+ * einstein notation:
+ * m2_{ij} = a m1_{ik} m2_{kj} _OR_ a m1_{kj} m2_{ik}
+ *
+ * nrow = # rows of m2
+ * ncol = # cols of m2
+ * TODO(PERF): make compute kernel
+ * TODO: finish all other flags
+ */
+void
+blas·trmm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2)
+{ /* Updates m2 in place from products of m1 and m2 entries scaled by a.
+   * The flag argument f is not consulted.
+   * NOTE(review): the statement below uses "+=", so each entry keeps its
+   * previous value and the scaled sum is added on top; the header comment
+   * describes a plain product m2 = a*m1*m2 -- confirm which is intended.
+   * NOTE(review): the m2 subscripts use stride ncol with i (bounded by
+   * nrow) as the fast index, while m1 uses stride nrow; these only agree
+   * when nrow == ncol -- confirm the intended storage layout. */
+ int i, j, k, len; /* len is declared but never used */
+
+ for (i = 0; i < nrow; i++) {
+ for (j = 0; j < ncol; j++) {
+ for (k = i; k < ncol; k++) { /* starts at k = i, i.e. skips entries of m1 with second index below i; NOTE(review): upper bound is ncol although k also indexes m1 with stride nrow -- confirm it should not be nrow */
+ m2[i + ncol*j] += a * m1[i + nrow*k] * m2[k + ncol*j];
+ }
+ }
+ }
+}
+
+/*
+ * solve triangular matrix system of equations
+ * m2 = a * m1^{-1} * m2 _OR_ a * m2 * m1^{-1}
+ * m1 is assumed triangular
+ *
+ * nrow = # rows of m2
+ * ncol = # cols of m2
+ * TODO: complete stub
+ * TODO: finish all other flags
+ */
+void
+blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2)
+{ /* Stub: the body is empty, so every argument is ignored and m2 is
+   * left unchanged. */
+}
+
+#define NITER 10000
+#define NCOL 5007
+#define NROW 5007
error
-test·gemm()
+test·level3()
{
int i, n;
clock_t t;
@@ -915,7 +1176,7 @@ print·array(int n, double *x)
}
error
-main()
+test·level1()
{
int ai, ai2, i, n;
double *x, *y;
@@ -932,24 +1193,27 @@ main()
for (n = 0; n < NITER; n++) {
for (i = 0; i < NCOL; i++) {
- y[n] = rng·random();
+ y[i] = rng·random();
}
memcpy(x, y, sizeof(*x)*NCOL);
t = clock();
- ai = blas·argmax(NCOL, x);
+ ai = blas·argmin(NCOL, x);
t = clock() - t;
tprof[0] += 1000.*t/CLOCKS_PER_SEC;
+ if (n == 20729) {
+ printf("[%d]=%f vs [%d]=%f\n", 74202, x[74202], 3, x[3]);
+ }
memcpy(x, y, sizeof(*x)*NCOL);
t = clock();
- ai2 = cblas_idamax(NCOL, x, 1);
+ ai2 = cblas_idamin(NCOL, x, 1);
t = clock() - t;
tprof[1] += 1000.*t/CLOCKS_PER_SEC;
if (ai != ai2) {
- printf("iteration %d: %d not equal to %d\n", n, ai, ai2);
+ printf("iteration %d: %d not equal to %d. %f vs %f\n", n, ai, ai2, x[ai], x[ai2]);
}
}
@@ -965,4 +1229,54 @@ main()
a = 10.234, b = 2.;
blas·rotg(&a, &b, &c, &s);
printf("%f, %f, %f, %f\n", a, b, c, s);
+
+ return 0;
+}
+
+/*
+ * Benchmark and cross-check driver for blas·tpsv.
+ *
+ * Each iteration fills a packed NCOL-by-NCOL triangular system with
+ * random data (offset by .1), solves it once with blas·tpsv and once
+ * with cblas_dtpsv, accumulates the wall time of both, and reports
+ * every component whose relative difference exceeds 1e-5.
+ *
+ * Returns 0 on completion, 1 if the work buffers cannot be allocated.
+ */
+error
+main()
+{
+    int i, j, n, it;
+    clock_t t;
+
+    double *x, *y, *m;
+    double tprof[2];
+
+    rng·init(0);
+
+    tprof[0] = 0, tprof[1] = 0;
+    x = malloc(sizeof(*x)*NCOL);
+    y = malloc(sizeof(*y)*NCOL);
+    m = malloc(sizeof(*m)*NCOL*(NCOL+1)/2);
+    if (!x || !y || !m) {
+        /* free(NULL) is a no-op, so a partial failure is handled too */
+        free(m);
+        free(y);
+        free(x);
+        errorf("out of memory");
+        return 1;
+    }
+
+    for (it = 0; it < NITER; it++) {
+        n = 0;
+        for (i = 0; i < NCOL; i++) {
+            y[i] = rng·random();
+            for (j = i; j < NCOL; j++) {
+                m[n++] = rng·random() + .1; // To ensure not singular
+            }
+        }
+
+        /* x is solved by our routine, y by the reference one */
+        memcpy(x, y, NCOL * sizeof(*x));
+
+        t = clock();
+        blas·tpsv(0, NCOL, m, x);
+        t = clock() - t;
+        tprof[0] += 1000.*t/CLOCKS_PER_SEC;
+
+        t = clock();
+        cblas_dtpsv(CblasRowMajor, CblasUpper, CblasNoTrans, CblasNonUnit, NCOL, m, y, 1);
+        t = clock() - t;
+        tprof[1] += 1000.*t/CLOCKS_PER_SEC;
+
+        for (i = 0; i < NCOL; i++) {
+            if (math·abs(x[i] - y[i])/math·abs(x[i]) > 1e-5) {
+                errorf("failure at index %d: %f != %f", i, x[i], y[i]);
+            }
+        }
+    }
+
+    printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER);
+    printf("mean time/iteration (oblas): %fms\n", tprof[1]/NITER);
+
+    free(m);
+    free(y);
+    free(x);
+
+    return 0;
+}
diff --git a/sys/libmath/linalg.c b/sys/libmath/linalg.c
new file mode 100644
index 0000000..57f799b
--- /dev/null
+++ b/sys/libmath/linalg.c
@@ -0,0 +1,5 @@
+#include <u.h>
+#include <libn.h>
+#include <libmath.h>
+
+