path: root/sys
author     Nicholas Noll <nbnoll@eml.cc>   2020-05-08 11:47:17 -0700
committer  Nicholas Noll <nbnoll@eml.cc>   2020-05-08 11:47:17 -0700
commit     d5e3041d34e4615ea8f81bd39a2a9231ef38253f (patch)
tree       c01e4b1df451e98357b9e2c62790543101b7b194 /sys
parent     36117f59ec77784c9ef77801d7c1cbf03a4c4a8b (diff)
Prototype of BLAS level 1 functions (double)
Functions run at ~90% of the speed of tested OpenBLAS functions
Diffstat (limited to 'sys')
-rw-r--r--    sys/libmath/blas.c    321
1 file changed, 295 insertions, 26 deletions
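
For orientation, here is a minimal usage sketch of the level-1 routines touched by this commit. It is editorial, not part of the patch: the include path is an assumption (the header is not shown in this diff), and only the blas· signatures visible in the patch are taken as given.

/*
 * Editorial usage sketch (not part of the commit).  The include path is
 * a guess, and the mid-dot identifiers (blas·axpy, ...) require a
 * compiler that accepts UTF-8 identifiers, as this codebase evidently does.
 */
#include <stdio.h>
#include <stdlib.h>
#include <libmath/blas.h>          /* hypothetical include path */

int
main(void)
{
	int    i, n = 1024, imax;
	double d, *x, *y;

	x = malloc(n * sizeof(*x));
	y = malloc(n * sizeof(*y));
	for (i = 0; i < n; i++)
		x[i] = i + 1., y[i] = 1. / (i + 1.);

	blas·copy(n, x, y);            /* y = x              */
	blas·scale(n, x, 2.);          /* x = 2x             */
	blas·axpy(n, -1., x, y);       /* y = -x + y         */
	blas·swap(n, x, y);            /* exchange x and y   */
	d    = blas·dot(n, x, y);      /* d = x·y            */
	imax = blas·argmax(n, x);      /* index of max entry */

	printf("dot = %f, argmax = %d\n", d, imax);
	free(x), free(y);
	return 0;
}
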
diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c
index 8343a42..f12e3e2 100644
--- a/sys/libmath/blas.c
+++ b/sys/libmath/blas.c
@@ -10,7 +10,7 @@
// level one
/*
- * Rotate vector
+ * rotate vector
* x = cos*x + sin*y
* y = cos*x - sin*y
*/
@@ -118,6 +118,9 @@ blas·rotg(double *a, double *b, double *cos, double *sin)
* operates on len points
*
* params = [flag, h11, h12, h21, h22]
+ * NOTE: the params layout is row-major, unlike other implementations
+ *
+ * Flags correspond to:
* @flag = -1:
* H -> [ [h11, h12], [h21, h22] ]
* @flag = 0.0:
@@ -126,6 +129,8 @@ blas·rotg(double *a, double *b, double *cos, double *sin)
* H -> [ [h11, 1], [-1, h22] ]
* @flag = -2:
* H -> [ [1, 0], [0, 1] ]
+ * @flag = *
+ * return error
*
* Replaces:
* x -> H11 * x + H12 * y
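
As a concrete reading of the flag encoding above (editorial aside, not from the patch): p[0] holds the flag and p[1..4] fill H row by row. With flag = -1 the matrix is used verbatim, H = [ [p[1], p[2]], [p[3], p[4]] ]; with flag = 0 the diagonal is fixed, H = [ [1, p[2]], [p[3], 1] ]; with flag = +1 the off-diagonal is fixed, H = [ [p[1], 1], [-1, p[4]] ]; with flag = -2, H is the identity and the data is left untouched; any other value is treated as an error. This is exactly the mapping the switch in blas·rotm below selects before dispatching to the kernel.
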
@@ -134,6 +139,31 @@ blas·rotg(double *a, double *b, double *cos, double *sin)
static
void
+rotm_kernel8_avx2(int n, double *x, double *y, double H[4])
+{
+ register int i;
+ __m256d x256, y256;
+ __m256d H256[4];
+
+ for (i = 0; i < 4; i++) {
+ H256[i] = _mm256_broadcastsd_pd(_mm_load_sd(H+i));
+ }
+
+ for (i = 0; i < n; i+=8) {
+ x256 = _mm256_loadu_pd(x+i+0);
+ y256 = _mm256_loadu_pd(y+i+0);
+ _mm256_storeu_pd(x+i+0, H256[0] * x256 + H256[1] * y256);
+ _mm256_storeu_pd(y+i+0, H256[2] * y256 - H256[3] * x256);
+
+ x256 = _mm256_loadu_pd(x+i+4);
+ y256 = _mm256_loadu_pd(y+i+4);
+ _mm256_storeu_pd(x+i+4, H256[0] * x256 + H256[1] * y256);
+ _mm256_storeu_pd(y+i+4, H256[2] * y256 - H256[3] * x256);
+ }
+}
+
+static
+void
rotm_kernel8(int n, double *x, double *y, double H[4])
{
register int i;
@@ -159,8 +189,8 @@ blas·rotm(int len, double *x, double *y, double p[5])
flag = math·round(p[0]);
switch (flag) {
- case 0: H[0] = p[1], H[1] = p[2], H[2] = p[3], H[3] = p[4]; break;
- case -1: H[0] = +1, H[1] = p[2], H[2] = p[3], H[3] = +1; break;
+ case -1: H[0] = p[1], H[1] = p[2], H[2] = p[3], H[3] = p[4]; break;
+ case 0: H[0] = +1, H[1] = p[2], H[2] = p[3], H[3] = +1; break;
case +1: H[0] = p[1], H[1] = +1, H[2] = -1, H[3] = p[4]; break;
case -2: H[0] = +1, H[1] = 0, H[2] = 0, H[3] = +1; break;
default:
@@ -169,7 +199,7 @@ blas·rotm(int len, double *x, double *y, double p[5])
}
n = len & ~7;
- rotm_kernel8(n, x, y, H);
+ rotm_kernel8_avx2(n, x, y, H);
for (; n < len; n++) {
tmp = x[n], x[n] = H[0]*x[n] + H[1]*y[n], y[n] = H[2]*y[n] + H[3]*tmp;
@@ -180,7 +210,7 @@ blas·rotm(int len, double *x, double *y, double p[5])
/*
- * Scale vector
+ * scale vector
* x = ax
*/
@@ -231,14 +261,78 @@ blas·scale(int len, double *x, double a)
}
/*
- * Daxpy
+ * copy
+ * y = x
+ */
+
+void
+blas·copy(int len, double *x, double *y)
+{
+ memcpy(y, x, sizeof(*x) * len);
+}
+
+/*
+ * swap
+ * y <=> x
+ */
+
+void
+swap_kernel8_avx2(int n, double *x, double *y)
+{
+ register int i;
+ __m256d tmp[2];
+ for (i = 0; i < n; i += 8) {
+ tmp[0] = _mm256_loadu_pd(x+i+0);
+ tmp[1] = _mm256_loadu_pd(y+i+0);
+ _mm256_storeu_pd(x+i+0, tmp[1]);
+ _mm256_storeu_pd(y+i+0, tmp[0]);
+
+ tmp[0] = _mm256_loadu_pd(x+i+4);
+ tmp[1] = _mm256_loadu_pd(y+i+4);
+ _mm256_storeu_pd(x+i+4, tmp[1]);
+ _mm256_storeu_pd(y+i+4, tmp[0]);
+ }
+}
+
+void
+swap_kernel8(int n, double *x, double *y)
+{
+ register int i;
+ register double tmp;
+ for (i = 0; i < n; i += 8) {
+ tmp = x[i+0], x[i+0] = y[i+0], y[i+0] = tmp;
+ tmp = x[i+1], x[i+1] = y[i+1], y[i+1] = tmp;
+ tmp = x[i+2], x[i+2] = y[i+2], y[i+2] = tmp;
+ tmp = x[i+3], x[i+3] = y[i+3], y[i+3] = tmp;
+ tmp = x[i+4], x[i+4] = y[i+4], y[i+4] = tmp;
+ tmp = x[i+5], x[i+5] = y[i+5], y[i+5] = tmp;
+ tmp = x[i+6], x[i+6] = y[i+6], y[i+6] = tmp;
+ tmp = x[i+7], x[i+7] = y[i+7], y[i+7] = tmp;
+ }
+}
+
+void
+blas·swap(int len, double *x, double *y)
+{
+ int n;
+ double tmp;
+
+ n = len & ~7;
+ swap_kernel8(n, x, y);
+ for (; n < len; n++) {
+ tmp = x[n], x[n] = y[n], y[n] = tmp;
+ }
+}
+
+/*
+ * daxpy
* y = ax + y
*/
static
void
-daxpy_kernel8_avx2(int n, double a, double *x, double *y)
+axpy_kernel8_avx2(int n, double a, double *x, double *y)
{
__m128d a128;
__m256d a256;
@@ -254,7 +348,7 @@ daxpy_kernel8_avx2(int n, double a, double *x, double *y)
static
void
-daxpy_kernel8(int n, double a, double *x, double *y)
+axpy_kernel8(int n, double a, double *x, double *y)
{
register int i;
for (i = 0; i < n; i += 8) {
@@ -270,22 +364,22 @@ daxpy_kernel8(int n, double a, double *x, double *y)
}
void
-blas·daxpy(int len, double a, double *x, double *y)
+blas·axpy(int len, double a, double *x, double *y)
{
int n;
n = len & ~7;
- daxpy_kernel8_avx2(n, a, x, y);
+ axpy_kernel8_avx2(n, a, x, y);
for (; n < len; n++) {
y[n] += a*x[n];
}
}
-/************************************************
- * Dot product
+/*
+ * dot product
* x·y
- ***********************************************/
+ */
static
double
@@ -378,6 +472,175 @@ blas·dot(int len, double *x, double *y)
return res;
}
+/*
+ * argmax
+ * i = argmax(x)
+ */
+
+int
+argmax_kernel8_avx2(int n, double *x)
+{
+ register int i, msk, maxidx[4];
+ __m256d val, cmp, max;
+
+ maxidx[0] = 0, maxidx[1] = 1, maxidx[2] = 2, maxidx[3] = 3;
+ max = _mm256_loadu_pd(x);
+
+ for (i = 0; i < n; i += 4) {
+ val = _mm256_loadu_pd(x+i);
+ cmp = _mm256_cmp_pd(val, max, _CMP_GT_OQ);
+ msk = _mm256_movemask_pd(cmp);
+ switch (msk) {
+ case 1:
+ max = _mm256_blend_pd(max, val, 1);
+ maxidx[0] = i+0;
+ break;
+
+ case 2:
+ max = _mm256_blend_pd(max, val, 2);
+ maxidx[1] = i+1;
+ break;
+
+ case 3:
+ max = _mm256_blend_pd(max, val, 3);
+ maxidx[0] = i+0;
+ maxidx[1] = i+1;
+ break;
+
+ case 4:
+ max = _mm256_blend_pd(max, val, 4);
+ maxidx[2] = i+2;
+ break;
+
+ case 5:
+ max = _mm256_blend_pd(max, val, 5);
+ maxidx[2] = i+2;
+ maxidx[0] = i+0;
+ break;
+
+ case 6:
+ max = _mm256_blend_pd(max, val, 6);
+ maxidx[2] = i+2;
+ maxidx[1] = i+1;
+ break;
+
+ case 7:
+ max = _mm256_blend_pd(max, val, 7);
+ maxidx[2] = i+2;
+ maxidx[1] = i+1;
+ maxidx[0] = i+0;
+ break;
+
+ case 8:
+ max = _mm256_blend_pd(max, val, 8);
+ maxidx[3] = i+3;
+ break;
+
+ case 9:
+ max = _mm256_blend_pd(max, val, 9);
+ maxidx[3] = i+3;
+ maxidx[0] = i+0;
+ break;
+
+ case 10:
+ max = _mm256_blend_pd(max, val, 10);
+ maxidx[3] = i+3;
+ maxidx[1] = i+1;
+ break;
+
+ case 11:
+ max = _mm256_blend_pd(max, val, 11);
+ maxidx[3] = i+3;
+ maxidx[1] = i+1;
+ maxidx[0] = i+0;
+ break;
+
+ case 12:
+ max = _mm256_blend_pd(max, val, 12);
+ maxidx[3] = i+3;
+ maxidx[2] = i+2;
+ break;
+
+ case 13:
+ max = _mm256_blend_pd(max, val, 13);
+ maxidx[3] = i+3;
+ maxidx[2] = i+2;
+ maxidx[0] = i+0;
+ break;
+
+ case 14:
+ max = _mm256_blend_pd(max, val, 14);
+ maxidx[3] = i+3;
+ maxidx[2] = i+2;
+ maxidx[1] = i+1;
+ break;
+
+ case 15:
+ max = _mm256_blend_pd(max, val, 15);
+ maxidx[3] = i+3;
+ maxidx[2] = i+2;
+ maxidx[1] = i+1;
+ maxidx[0] = i+0;
+ break;
+
+ case 0:
+ default: ;
+ }
+ }
+ maxidx[0] = (x[maxidx[0]] > x[maxidx[1]]) ? maxidx[0] : maxidx[1];
+ maxidx[1] = (x[maxidx[2]] > x[maxidx[3]]) ? maxidx[2] : maxidx[3];
+ maxidx[0] = (x[maxidx[0]] > x[maxidx[1]]) ? maxidx[0] : maxidx[1];
+
+ return maxidx[0];
+}
+
+int
+argmax_kernel8(int n, double *x)
+{
+#define SET(d) idx[d] = 0, max[d] = x[d]
+#define PUT(d) if (x[i+d] > max[d]) idx[d] = i+d, max[d] = x[i+d]
+ int i, idx[8];
+ double max[8];
+
+ SET(0); SET(1); SET(2); SET(3);
+ SET(4); SET(5); SET(6); SET(7);
+
+ for (i = 0; i < n; i += 8) {
+ PUT(0); PUT(1); PUT(2); PUT(3);
+ PUT(4); PUT(5); PUT(6); PUT(7);
+ }
+
+ n = 0;
+ for (i = 1; i < 8; i++ ) {
+ if (max[i] > max[n]) {
+ n = i;
+ }
+ }
+ return idx[n];
+#undef PUT
+#undef SET
+}
+
+int
+blas·argmax(int len, double *x)
+{
+ int i, n;
+ double max;
+
+ i = len & ~7;
+ n = argmax_kernel8_avx2(i, x);
+
+ max = x[n];
+ for (; i < len; i++) {
+ if (x[i] > max) {
+ n = i;
+ max = x[i];
+ }
+ }
+
+ return n;
+}
+
// -----------------------------------------------------------------------
// level two
@@ -563,7 +826,7 @@ blas·gemm(int n1, int n2, int n3, double a, double *m1, double *m2, double b, d
len = n1 & ~7;
for (j = 0; j < n2; j++) {
for (k = 0; k < n3; k++) {
- daxpy_kernel8_avx2(len, a * m2[k + n2*j], m1 + n3*k, m3 + n2*j);
+ axpy_kernel8_avx2(len, a * m2[k + n2*j], m1 + n3*k, m3 + n2*j);
// remainder
for (i = len; i < n1; i++) {
@@ -573,9 +836,9 @@ blas·gemm(int n1, int n2, int n3, double a, double *m1, double *m2, double b, d
}
}
-#define NITER 5000
-#define NCOL 10007
-#define NROW 10007
+#define NITER 30000
+#define NCOL 50007
+#define NROW 107
error
test·gemm()
@@ -654,34 +917,40 @@ print·array(int n, double *x)
error
main()
{
- int i, n;
+ int ai, ai2, i, n;
double *x, *y;
double tprof[2];
+ // double params[5];
clock_t t;
x = malloc(sizeof(*x)*NCOL);
y = malloc(sizeof(*x)*NCOL);
+ rng·init(0);
+
+ // params[0] = -1.;
+ // params[1] = 100; params[2] = 20; params[3] = 30; params[4] = 10;
for (n = 0; n < NITER; n++) {
for (i = 0; i < NCOL; i++) {
- x[i] = i*i+1;
- y[i] = i+1;
+ y[n] = rng·random();
}
+ memcpy(x, y, sizeof(*x)*NCOL);
t = clock();
- blas·rot(NCOL, x, y, .707, .707);
+ ai = blas·argmax(NCOL, x);
t = clock() - t;
tprof[0] += 1000.*t/CLOCKS_PER_SEC;
- for (i = 0; i < NCOL; i++) {
- x[i] = i*i+1;
- y[i] = i+1;
- }
+ memcpy(x, y, sizeof(*x)*NCOL);
t = clock();
- cblas_drot(NCOL, x, 1, y, 1, .707, .707);
+ ai2 = cblas_idamax(NCOL, x, 1);
t = clock() - t;
tprof[1] += 1000.*t/CLOCKS_PER_SEC;
+
+ if (ai != ai2) {
+ printf("iteration %d: %d not equal to %d\n", n, ai, ai2);
+ }
}
printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER);