From 3290956dc738abf4998adb94a96a3eff487fae3c Mon Sep 17 00:00:00 2001
From: Nicholas Noll
Date: Sat, 9 May 2020 12:39:53 -0700
Subject: fix: mathematical bug in rotm. also generalized to allow for non unity increments

---
Review notes (text in this section is ignored when the patch is applied):
  * The unrolled strided loop added to blas·rotm computes the y update as
    H[2]*x_old + H[3]*y, the same formula this patch installs in
    rotm_kernel8, rotm_kernel8_avx2 and the scalar tail loop of blas·rotm.
  * Leading whitespace of the code lines was reconstructed from a
    whitespace-collapsed copy; confirm it against the tree before applying.
  * The blob ids on the index line are carried over unchanged.

 sys/libmath/blas.c | 69 +++++++++++++++++++++++++++++++++++++-----------------
 1 file changed, 47 insertions(+), 22 deletions(-)

(limited to 'sys/libmath/blas.c')

diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c
index f1ae09d..c419c6f 100644
--- a/sys/libmath/blas.c
+++ b/sys/libmath/blas.c
@@ -63,17 +63,30 @@ rot_kernel8(int n, double *x, double *y, double cos, double sin)
 }
 
 void
-blas·rot(int len, double *x, double *y, double cos, double sin)
+blas·rot(int len, double *x, int incx, double *y, int incy, double cos, double sin)
 {
-    register int n;
+    register int i, n;
     register double tmp;
 
-    n = len & ~7;
-    rot_kernel8_avx2(n, x, y, cos, sin);
+    if (incx == 1 && incy == 1) {
+        n = EVEN_BY(len, 8);
+        rot_kernel8_avx2(n, x, y, cos, sin);
+        x += n;
+        y += n;
+    } else {
+        n = EVEN_BY(len, 4);
+        for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) {
+            tmp = x[0*incx], x[0*incx] = cos*x[0*incx] + sin*y[0*incy], y[0*incy] = cos*y[0*incy] - sin*tmp;
+            tmp = x[1*incx], x[1*incx] = cos*x[1*incx] + sin*y[1*incy], y[1*incy] = cos*y[1*incy] - sin*tmp;
+            tmp = x[2*incx], x[2*incx] = cos*x[2*incx] + sin*y[2*incy], y[2*incy] = cos*y[2*incy] - sin*tmp;
+            tmp = x[3*incx], x[3*incx] = cos*x[3*incx] + sin*y[3*incy], y[3*incy] = cos*y[3*incy] - sin*tmp;
+        }
+    }
 
-    for (; n < len; n++) {
-        tmp = x[n], x[n] = cos*x[n] + sin*y[n], y[n] = cos*y[n]- sin*tmp;
+    for (; n < len; n++, x += incx, y += incy) {
+        tmp = x[0], x[0] = cos*x[0] + sin*y[0], y[0] = cos*y[0] - sin*tmp;
     }
+
 }
 
 /*
@@ -153,12 +166,12 @@ rotm_kernel8_avx2(int n, double *x, double *y, double H[4])
         x256 = _mm256_loadu_pd(x+i+0);
         y256 = _mm256_loadu_pd(y+i+0);
         _mm256_storeu_pd(x+i+0, H256[0] * x256 + H256[1] * y256);
-        _mm256_storeu_pd(y+i+0, H256[2] * y256 - H256[3] * x256);
+        _mm256_storeu_pd(y+i+0, H256[2] * x256 + H256[3] * y256);
 
         x256 = _mm256_loadu_pd(x+i+4);
         y256 = _mm256_loadu_pd(y+i+4);
         _mm256_storeu_pd(x+i+4, H256[0] * x256 + H256[1] * y256);
-        _mm256_storeu_pd(y+i+4, H256[2] * y256 - H256[3] * x256);
+        _mm256_storeu_pd(y+i+4, H256[2] * x256 + H256[3] * y256);
     }
 }
 
@@ -170,21 +183,21 @@ rotm_kernel8(int n, double *x, double *y, double H[4])
     register double tmp;
 
     for (i = 0; i < n; i+=8) {
-        tmp = x[i+0], x[i+0] = H[0]*x[i+0] + H[1]*y[i+0], y[i+0] = H[2]*y[i+0] + H[3]*tmp;
-        tmp = x[i+1], x[i+1] = H[0]*x[i+1] + H[1]*y[i+1], y[i+1] = H[2]*y[i+1] + H[3]*tmp;
-        tmp = x[i+2], x[i+2] = H[0]*x[i+2] + H[1]*y[i+2], y[i+2] = H[2]*y[i+2] + H[3]*tmp;
-        tmp = x[i+3], x[i+3] = H[0]*x[i+3] + H[1]*y[i+3], y[i+3] = H[2]*y[i+3] + H[3]*tmp;
-        tmp = x[i+4], x[i+4] = H[0]*x[i+4] + H[1]*y[i+4], y[i+4] = H[2]*y[i+4] + H[3]*tmp;
-        tmp = x[i+5], x[i+5] = H[0]*x[i+5] + H[1]*y[i+5], y[i+5] = H[2]*y[i+5] + H[3]*tmp;
-        tmp = x[i+6], x[i+6] = H[0]*x[i+6] + H[1]*y[i+6], y[i+6] = H[2]*y[i+6] + H[3]*tmp;
-        tmp = x[i+7], x[i+7] = H[0]*x[i+7] + H[1]*y[i+7], y[i+7] = H[2]*y[i+7] + H[3]*tmp;
+        tmp = x[i+0], x[i+0] = H[0]*x[i+0] + H[1]*y[i+0], y[i+0] = H[2]*tmp + H[3]*y[i+0];
+        tmp = x[i+1], x[i+1] = H[0]*x[i+1] + H[1]*y[i+1], y[i+1] = H[2]*tmp + H[3]*y[i+1];
+        tmp = x[i+2], x[i+2] = H[0]*x[i+2] + H[1]*y[i+2], y[i+2] = H[2]*tmp + H[3]*y[i+2];
+        tmp = x[i+3], x[i+3] = H[0]*x[i+3] + H[1]*y[i+3], y[i+3] = H[2]*tmp + H[3]*y[i+3];
+        tmp = x[i+4], x[i+4] = H[0]*x[i+4] + H[1]*y[i+4], y[i+4] = H[2]*tmp + H[3]*y[i+4];
+        tmp = x[i+5], x[i+5] = H[0]*x[i+5] + H[1]*y[i+5], y[i+5] = H[2]*tmp + H[3]*y[i+5];
+        tmp = x[i+6], x[i+6] = H[0]*x[i+6] + H[1]*y[i+6], y[i+6] = H[2]*tmp + H[3]*y[i+6];
+        tmp = x[i+7], x[i+7] = H[0]*x[i+7] + H[1]*y[i+7], y[i+7] = H[2]*tmp + H[3]*y[i+7];
     }
 }
 
 error
-blas·rotm(int len, double *x, double *y, double p[5])
+blas·rotm(int len, double *x, int incx, double *y, int incy, double p[5])
 {
-    int n, flag;
+    int i, n, flag;
     double tmp, H[4];
 
     flag = math·round(p[0]);
@@ -198,11 +211,23 @@ blas·rotm(int len, double *x, double *y, double p[5])
         return 1;
     }
 
-    n = len & ~7;
-    rotm_kernel8_avx2(n, x, y, H);
+    if (incx == 1 && incy == 1) {
+        n = EVEN_BY(len, 8);
+        rotm_kernel8_avx2(n, x, y, H);
+        x += n;
+        y += n;
+    } else {
+        n = EVEN_BY(len, 4);
+        for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) {
+            tmp = x[0*incx], x[0*incx] = H[0]*x[0*incx] + H[1]*y[0*incy], y[0*incy] = H[2]*tmp + H[3]*y[0*incy];
+            tmp = x[1*incx], x[1*incx] = H[0]*x[1*incx] + H[1]*y[1*incy], y[1*incy] = H[2]*tmp + H[3]*y[1*incy];
+            tmp = x[2*incx], x[2*incx] = H[0]*x[2*incx] + H[1]*y[2*incy], y[2*incy] = H[2]*tmp + H[3]*y[2*incy];
+            tmp = x[3*incx], x[3*incx] = H[0]*x[3*incx] + H[1]*y[3*incy], y[3*incy] = H[2]*tmp + H[3]*y[3*incy];
+        }
+    }
 
-    for (; n < len; n++) {
-        tmp = x[n], x[n] = H[0]*x[n] + H[1]*y[n], y[n] = H[2]*y[n] + H[3]*tmp;
+    for (; n < len; n++, x += incx, y += incy) {
+        tmp = x[0], x[0] = H[0]*x[0] + H[1]*y[0], y[0] = H[2]*tmp + H[3]*y[0];
     }
 
     return 0;
-- 
cgit v1.2.1