From d5e3041d34e4615ea8f81bd39a2a9231ef38253f Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Fri, 8 May 2020 11:47:17 -0700 Subject: Prototype of BLAS level 1 functions (double) Functions run at ~90% of the speed of tested OpenBLAS functions --- sys/libmath/blas.c | 321 ++++++++++++++++++++++++++++++++++++++++++++++++----- 1 file changed, 295 insertions(+), 26 deletions(-) (limited to 'sys') diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c index 8343a42..f12e3e2 100644 --- a/sys/libmath/blas.c +++ b/sys/libmath/blas.c @@ -10,7 +10,7 @@ // level one /* - * Rotate vector + * rotate vector * x = cos*x + sin*y * y = cos*x - sin*y */ @@ -118,6 +118,9 @@ blas·rotg(double *a, double *b, double *cos, double *sin) * operates on len points * * params = [flag, h11, h12, h21, h22] + * NOTE: This is row major as opposed to other implementations + * + * Flags correspond to: * @flag = -1: * H -> [ [h11, h12], [h21, h22] ] * @flag = 0.0: @@ -126,12 +129,39 @@ blas·rotg(double *a, double *b, double *cos, double *sin) * H -> [ [h11, 1], [-1, h22] ] * @flag = -2: * H -> [ [1, 0], [0, 1] ] + * @flag = * + * return error * * Replaces: * x -> H11 * x + H12 * y * y -> H21 * x + H22 * y */ +static +void +rotm_kernel8_avx2(int n, double *x, double *y, double H[4]) +{ + register int i; + __m256d x256, y256; + __m256d H256[4]; + + for (i = 0; i < 4; i++) { + H256[i] = _mm256_broadcastsd_pd(_mm_load_sd(H+i)); + } + + for (i = 0; i < n; i+=8) { + x256 = _mm256_loadu_pd(x+i+0); + y256 = _mm256_loadu_pd(y+i+0); + _mm256_storeu_pd(x+i+0, H256[0] * x256 + H256[1] * y256); + _mm256_storeu_pd(y+i+0, H256[2] * y256 - H256[3] * x256); + + x256 = _mm256_loadu_pd(x+i+4); + y256 = _mm256_loadu_pd(y+i+4); + _mm256_storeu_pd(x+i+4, H256[0] * x256 + H256[1] * y256); + _mm256_storeu_pd(y+i+4, H256[2] * y256 - H256[3] * x256); + } +} + static void rotm_kernel8(int n, double *x, double *y, double H[4]) @@ -159,8 +189,8 @@ blas·rotm(int len, double *x, double *y, double p[5]) flag = math·round(p[0]); switch (flag) { - case 0: H[0] = p[1], H[1] = p[2], H[2] = p[3], H[3] = p[4]; break; - case -1: H[0] = +1, H[1] = p[2], H[2] = p[3], H[3] = +1; break; + case -1: H[0] = p[1], H[1] = p[2], H[2] = p[3], H[3] = p[4]; break; + case 0: H[0] = +1, H[1] = p[2], H[2] = p[3], H[3] = +1; break; case +1: H[0] = p[1], H[1] = +1, H[2] = -1, H[3] = p[4]; break; case -2: H[0] = +1, H[1] = 0, H[2] = 0, H[3] = +1; break; default: @@ -169,7 +199,7 @@ blas·rotm(int len, double *x, double *y, double p[5]) } n = len & ~7; - rotm_kernel8(n, x, y, H); + rotm_kernel8_avx2(n, x, y, H); for (; n < len; n++) { tmp = x[n], x[n] = H[0]*x[n] + H[1]*y[n], y[n] = H[2]*y[n] + H[3]*tmp; @@ -180,7 +210,7 @@ blas·rotm(int len, double *x, double *y, double p[5]) /* - * Scale vector + * scale vector * x = ax */ @@ -231,14 +261,78 @@ blas·scale(int len, double *x, double a) } /* - * Daxpy + * copy + * y = x + */ + +void +blas·copy(int len, double *x, double *y) +{ + memcpy(y, x, sizeof(*x) * len); +} + +/* + * swap + * y <=> x + */ + +void +swap_kernel8_avx2(int n, double *x, double *y) +{ + register int i; + __m256d tmp[2]; + for (i = 0; i < n; i += 8) { + tmp[0] = _mm256_loadu_pd(x+i+0); + tmp[1] = _mm256_loadu_pd(y+i+0); + _mm256_storeu_pd(x+i+0, tmp[1]); + _mm256_storeu_pd(y+i+0, tmp[0]); + + tmp[0] = _mm256_loadu_pd(x+i+4); + tmp[1] = _mm256_loadu_pd(y+i+4); + _mm256_storeu_pd(x+i+4, tmp[1]); + _mm256_storeu_pd(y+i+4, tmp[0]); + } +} + +void +swap_kernel8(int n, double *x, double *y) +{ + register int i; + register double tmp; + for (i = 0; i < n; i += 8) { + tmp = x[i+0], x[i+0] = y[i+0], y[i+0] = tmp; + tmp = x[i+1], x[i+1] = y[i+1], y[i+1] = tmp; + tmp = x[i+2], x[i+2] = y[i+2], y[i+2] = tmp; + tmp = x[i+3], x[i+3] = y[i+3], y[i+3] = tmp; + tmp = x[i+4], x[i+4] = y[i+4], y[i+4] = tmp; + tmp = x[i+5], x[i+5] = y[i+5], y[i+5] = tmp; + tmp = x[i+6], x[i+6] = y[i+6], y[i+6] = tmp; + tmp = x[i+7], x[i+7] = y[i+7], y[i+7] = tmp; + } +} + +void +blas·swap(int len, double *x, double *y) +{ + int n; + double tmp; + + n = len & ~7; + swap_kernel8(n, x, y); + for (; n < len; n++) { + tmp = x[n], x[n] = y[n], y[n] = tmp; + } +} + +/* + * daxpy * y = ax + y */ static void -daxpy_kernel8_avx2(int n, double a, double *x, double *y) +axpy_kernel8_avx2(int n, double a, double *x, double *y) { __m128d a128; __m256d a256; @@ -254,7 +348,7 @@ daxpy_kernel8_avx2(int n, double a, double *x, double *y) static void -daxpy_kernel8(int n, double a, double *x, double *y) +axpy_kernel8(int n, double a, double *x, double *y) { register int i; for (i = 0; i < n; i += 8) { @@ -270,22 +364,22 @@ daxpy_kernel8(int n, double a, double *x, double *y) } void -blas·daxpy(int len, double a, double *x, double *y) +blas·axpy(int len, double a, double *x, double *y) { int n; n = len & ~7; - daxpy_kernel8_avx2(n, a, x, y); + axpy_kernel8_avx2(n, a, x, y); for (; n < len; n++) { y[n] += a*x[n]; } } -/************************************************ - * Dot product +/* + * dot product * x·y - ***********************************************/ + */ static double @@ -378,6 +472,175 @@ blas·dot(int len, double *x, double *y) return res; } +/* + * argmax + * i = argmax(x) + */ + +int +argmax_kernel8_avx2(int n, double *x) +{ + register int i, msk, maxidx[4]; + __m256d val, cmp, max; + + maxidx[0] = 0, maxidx[1] = 1, maxidx[2] = 2, maxidx[3] = 3; + max = _mm256_loadu_pd(x); + + for (i = 0; i < n; i += 4) { + val = _mm256_loadu_pd(x+i); + cmp = _mm256_cmp_pd(val, max, _CMP_GT_OQ); + msk = _mm256_movemask_pd(cmp); + switch (msk) { + case 1: + max = _mm256_blend_pd(max, val, 1); + maxidx[0] = i+0; + break; + + case 2: + max = _mm256_blend_pd(max, val, 2); + maxidx[1] = i+1; + break; + + case 3: + max = _mm256_blend_pd(max, val, 3); + maxidx[0] = i+0; + maxidx[1] = i+1; + break; + + case 4: + max = _mm256_blend_pd(max, val, 4); + maxidx[2] = i+2; + break; + + case 5: + max = _mm256_blend_pd(max, val, 5); + maxidx[2] = i+2; + maxidx[0] = i+0; + break; + + case 6: + max = _mm256_blend_pd(max, val, 6); + maxidx[2] = i+2; + maxidx[1] = i+1; + break; + + case 7: + max = _mm256_blend_pd(max, val, 7); + maxidx[2] = i+2; + maxidx[1] = i+1; + maxidx[0] = i+0; + break; + + case 8: + max = _mm256_blend_pd(max, val, 8); + maxidx[3] = i+3; + break; + + case 9: + max = _mm256_blend_pd(max, val, 9); + maxidx[3] = i+3; + maxidx[0] = i+0; + break; + + case 10: + max = _mm256_blend_pd(max, val, 10); + maxidx[3] = i+3; + maxidx[1] = i+1; + break; + + case 11: + max = _mm256_blend_pd(max, val, 11); + maxidx[3] = i+3; + maxidx[1] = i+1; + maxidx[0] = i+0; + break; + + case 12: + max = _mm256_blend_pd(max, val, 12); + maxidx[3] = i+3; + maxidx[2] = i+2; + break; + + case 13: + max = _mm256_blend_pd(max, val, 13); + maxidx[3] = i+3; + maxidx[2] = i+2; + maxidx[0] = i+0; + break; + + case 14: + max = _mm256_blend_pd(max, val, 14); + maxidx[3] = i+3; + maxidx[2] = i+2; + maxidx[1] = i+1; + break; + + case 15: + max = _mm256_blend_pd(max, val, 15); + maxidx[3] = i+3; + maxidx[2] = i+2; + maxidx[1] = i+1; + maxidx[0] = i+0; + break; + + case 0: + default: ; + } + } + maxidx[0] = (x[maxidx[0]] > x[maxidx[1]]) ? maxidx[0] : maxidx[1]; + maxidx[1] = (x[maxidx[2]] > x[maxidx[3]]) ? maxidx[2] : maxidx[3]; + maxidx[0] = (x[maxidx[0]] > x[maxidx[1]]) ? maxidx[0] : maxidx[1]; + + return maxidx[0]; +} + +int +argmax_kernel8(int n, double *x) +{ +#define SET(d) idx[d] = 0, max[d] = x[d] +#define PUT(d) if (x[i+d] > max[d]) idx[d] = i+d, max[d] = x[i+d] + int i, idx[8]; + double max[8]; + + SET(0); SET(1); SET(2); SET(3); + SET(4); SET(5); SET(6); SET(7); + + for (i = 0; i < n; i += 8) { + PUT(0); PUT(1); PUT(2); PUT(3); + PUT(4); PUT(5); PUT(6); PUT(7); + } + + n = 0; + for (i = 1; i < 8; i++ ) { + if (max[i] > max[n]) { + n = i; + } + } + return idx[n]; +#undef PUT +#undef SET +} + +int +blas·argmax(int len, double *x) +{ + int i, n; + double max; + + i = len & ~7; + n = argmax_kernel8_avx2(i, x); + + max = x[n]; + for (; i < len; i++) { + if (x[i] > max) { + n = i; + max = x[i]; + } + } + + return n; +} + // ----------------------------------------------------------------------- // level two @@ -563,7 +826,7 @@ blas·gemm(int n1, int n2, int n3, double a, double *m1, double *m2, double b, d len = n1 & ~7; for (j = 0; j < n2; j++) { for (k = 0; k < n3; k++) { - daxpy_kernel8_avx2(len, a * m2[k + n2*j], m1 + n3*k, m3 + n2*j); + axpy_kernel8_avx2(len, a * m2[k + n2*j], m1 + n3*k, m3 + n2*j); // remainder for (i = len; i < n1; i++) { @@ -573,9 +836,9 @@ blas·gemm(int n1, int n2, int n3, double a, double *m1, double *m2, double b, d } } -#define NITER 5000 -#define NCOL 10007 -#define NROW 10007 +#define NITER 30000 +#define NCOL 50007 +#define NROW 107 error test·gemm() @@ -654,34 +917,40 @@ print·array(int n, double *x) error main() { - int i, n; + int ai, ai2, i, n; double *x, *y; double tprof[2]; + // double params[5]; clock_t t; x = malloc(sizeof(*x)*NCOL); y = malloc(sizeof(*x)*NCOL); + rng·init(0); + + // params[0] = -1.; + // params[1] = 100; params[2] = 20; params[3] = 30; params[4] = 10; for (n = 0; n < NITER; n++) { for (i = 0; i < NCOL; i++) { - x[i] = i*i+1; - y[i] = i+1; + y[n] = rng·random(); } + memcpy(x, y, sizeof(*x)*NCOL); t = clock(); - blas·rot(NCOL, x, y, .707, .707); + ai = blas·argmax(NCOL, x); t = clock() - t; tprof[0] += 1000.*t/CLOCKS_PER_SEC; - for (i = 0; i < NCOL; i++) { - x[i] = i*i+1; - y[i] = i+1; - } + memcpy(x, y, sizeof(*x)*NCOL); t = clock(); - cblas_drot(NCOL, x, 1, y, 1, .707, .707); + ai2 = cblas_idamax(NCOL, x, 1); t = clock() - t; tprof[1] += 1000.*t/CLOCKS_PER_SEC; + + if (ai != ai2) { + printf("iteration %d: %d not equal to %d\n", n, ai, ai2); + } } printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER); -- cgit v1.2.1