From eaa498806479e30a6a825afaff63e2a1fe5702f9 Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Sat, 9 May 2020 16:35:04 -0700 Subject: added helper function for horizontal sums --- sys/libmath/blas.c | 187 +++++++++++++++++++++++++++++++++++++++++++++-------- 1 file changed, 160 insertions(+), 27 deletions(-) (limited to 'sys/libmath/blas.c') diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c index 7056c08..18a43a6 100644 --- a/sys/libmath/blas.c +++ b/sys/libmath/blas.c @@ -8,6 +8,26 @@ #define EVEN_BY(x, n) (x) & ~((n)-1) +// ----------------------------------------------------------------------- +// misc functions + +static +inline +double +hsum_avx2(__m256d x) +{ + __m128d lo128, hi128, hi64; + + lo128 = _mm256_castpd256_pd128(x); + hi128 = _mm256_extractf128_pd(x, 1); + lo128 = _mm_add_pd(lo128, hi128); + + hi64 = _mm_unpackhi_pd(lo128, lo128); + + return _mm_cvtsd_f64(_mm_add_sd(lo128, hi64)); +} + + // ----------------------------------------------------------------------- // level one @@ -482,10 +502,7 @@ dot_kernel8_avx2(int len, double *x, double *y) sum[0] += sum[1] + sum[2] + sum[3]; - res = _mm_add_pd(_mm256_extractf128_pd(sum[0], 0), _mm256_extractf128_pd(sum[0], 1)); - res = _mm_hadd_pd(res, res); - - return res[0]; + return hsum_avx2(sum[0]); } static @@ -509,10 +526,7 @@ dot_kernel8_fma3(int len, double *x, double *y) sum[0] += sum[1] + sum[2] + sum[3]; - res = _mm_add_pd(_mm256_extractf128_pd(sum[0], 0), _mm256_extractf128_pd(sum[0], 1)); - res = _mm_hadd_pd(res, res); - - return res[0]; + return hsum_avx2(sum[0]); } static @@ -604,10 +618,7 @@ sum_kernel8_avx2(int len, double *x) sum[0] += sum[1]; - res = _mm_add_pd(_mm256_extractf128_pd(sum[0], 0), _mm256_extractf128_pd(sum[0], 1)); - res = _mm_hadd_pd(res, res); - - return res[0]; + return hsum_avx2(sum[0]); } static @@ -1108,9 +1119,44 @@ blas·tpsv(blas·Flag f, int n, double *m, double *x) * general affine transformation * y = aMx + by */ + +static +void +gemv_8xN_kernel4_avx2(int ncol, double *row[8], double *x, double *y) +{ + int c; + __m128d hr; + __m256d x256, r256[8]; + + for (c = 0; c < 8; c++) { + r256[c] = _mm256_setzero_pd(); + } + + for (c = 0; c < ncol; c += 4) { + x256 = _mm256_loadu_pd(x+c); + r256[0] += x256 * _mm256_loadu_pd(row[0] + c); + r256[1] += x256 * _mm256_loadu_pd(row[1] + c); + r256[2] += x256 * _mm256_loadu_pd(row[2] + c); + r256[3] += x256 * _mm256_loadu_pd(row[3] + c); + r256[4] += x256 * _mm256_loadu_pd(row[4] + c); + r256[5] += x256 * _mm256_loadu_pd(row[5] + c); + r256[6] += x256 * _mm256_loadu_pd(row[6] + c); + r256[7] += x256 * _mm256_loadu_pd(row[7] + c); + } + + y[0] = hsum_avx2(r256[0]); + y[1] = hsum_avx2(r256[1]); + y[2] = hsum_avx2(r256[2]); + y[3] = hsum_avx2(r256[3]); + y[4] = hsum_avx2(r256[4]); + y[5] = hsum_avx2(r256[5]); + y[6] = hsum_avx2(r256[6]); + y[7] = hsum_avx2(r256[7]); +} + static void -gemv_4xN_kernel4_avx2(int ncol, double **row, double *x, double *y) +gemv_4xN_kernel4_avx2(int ncol, double *row[4], double *x, double *y) { int c; __m128d hr; @@ -1128,16 +1174,15 @@ gemv_4xN_kernel4_avx2(int ncol, double **row, double *x, double *y) r256[3] += x256 * _mm256_loadu_pd(row[3] + c); } - for (c = 0; c < 4; c++) { - hr = _mm_add_pd(_mm256_extractf128_pd(r256[c], 0), _mm256_extractf128_pd(r256[c], 1)); - hr = _mm_hadd_pd(hr, hr); - y[c] = hr[0]; - } + y[0] = hsum_avx2(r256[0]); + y[1] = hsum_avx2(r256[1]); + y[2] = hsum_avx2(r256[2]); + y[3] = hsum_avx2(r256[3]); } static void -gemv_4xN_kernel4(int n, double **row, double *x, double *y) +gemv_4xN_kernel4(int n, double *row[4], double *x, double *y) { int c; double res[4]; @@ -1160,6 +1205,47 @@ gemv_4xN_kernel4(int n, double **row, double *x, double *y) y[3] = res[3]; } +static +void +gemv_2xN_kernel4_avx2(int n, double *row[2], double *x, double *y) +{ + int c; + __m128d hr; + __m256d x256, r256[2]; + + for (c = 0; c < 2; c++) { + r256[c] = _mm256_setzero_pd(); + } + + for (c = 0; c < n; c += 4) { + x256 = _mm256_loadu_pd(x+c); + r256[0] += x256 * _mm256_loadu_pd(row[0] + c); + r256[1] += x256 * _mm256_loadu_pd(row[1] + c); + } + + y[0] = hsum_avx2(r256[0]); + y[1] = hsum_avx2(r256[1]); +} + +static +void +gemv_2xN_kernel4(int n, double *row[2], double *x, double *y) +{ + int c; + double res[2]; + + res[0] = 0.0; + res[1] = 0.0; + + for (c = 0; c < n; c += 4) { + res[0] += row[0][c+0]*x[c+0] + row[0][c+1]*x[c+1] + row[0][c+2]*x[c+2] + row[0][c+3]*x[c+3]; + res[1] += row[1][c+0]*x[c+0] + row[1][c+1]*x[c+1] + row[1][c+2]*x[c+2] + row[1][c+3]*x[c+3]; + } + + y[0] = res[0]; + y[1] = res[1]; +} + static void gemv_1xN_kernel4_avx2(int n, double *row, double *x, double *y) @@ -1173,10 +1259,7 @@ gemv_1xN_kernel4_avx2(int n, double *row, double *x, double *y) r256 += _mm256_loadu_pd(row+c) * _mm256_loadu_pd(x+c); } - r128 = _mm_add_pd(_mm256_extractf128_pd(r256, 0), _mm256_extractf128_pd(r256, 1)); - r128 = _mm_hadd_pd(r128, r128); - - *y = r128[0]; + *y = hsum_avx2(r256); } static @@ -1198,12 +1281,47 @@ error blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y) { int c, r, nr, nc; - double *row[4], res[4]; + double *row[8], res[8]; - nr = EVEN_BY(nrow, 4); nc = EVEN_BY(ncol, 4); - for (r = 0; r < nr; r += 4) { + nr = EVEN_BY(nrow, 8); + for (r = 0; r < nr; r += 8) { + /* assumes row major layout */ + row[0] = m + ((r+0) * ncol); + row[1] = m + ((r+1) * ncol); + row[2] = m + ((r+2) * ncol); + row[3] = m + ((r+3) * ncol); + row[4] = m + ((r+4) * ncol); + row[5] = m + ((r+5) * ncol); + row[6] = m + ((r+6) * ncol); + row[7] = m + ((r+7) * ncol); + + gemv_8xN_kernel4_avx2(nc, row, x, res); + + for (c = nc; c < ncol; c++) { + res[0] += row[0][c]*x[c]; + res[1] += row[1][c]*x[c]; + res[2] += row[2][c]*x[c]; + res[3] += row[3][c]*x[c]; + res[4] += row[4][c]*x[c]; + res[5] += row[5][c]*x[c]; + res[6] += row[6][c]*x[c]; + res[7] += row[7][c]*x[c]; + } + + y[r+0] = a*res[0] + b*y[r+0]; + y[r+1] = a*res[1] + b*y[r+1]; + y[r+2] = a*res[2] + b*y[r+2]; + y[r+3] = a*res[3] + b*y[r+3]; + y[r+4] = a*res[4] + b*y[r+4]; + y[r+5] = a*res[5] + b*y[r+5]; + y[r+6] = a*res[6] + b*y[r+6]; + y[r+7] = a*res[7] + b*y[r+7]; + } + + nr = EVEN_BY(nrow, 4); + for (; r < nr; r += 4) { /* assumes row major layout */ row[0] = m + ((r+0) * ncol); row[1] = m + ((r+1) * ncol); @@ -1225,8 +1343,23 @@ blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double y[r+3] = a*res[3] + b*y[r+3]; } + nr = EVEN_BY(nrow, 2); + for (; r < nr; r += 2) { + row[0] = m + ((r+0) * ncol); + row[1] = m + ((r+1) * ncol); + gemv_2xN_kernel4_avx2(nc, row, x, res); + + for (c = nc; c < ncol; c++) { + res[0] += row[0][c]*x[c]; + res[1] += row[1][c]*x[c]; + } + + y[r+0] = a*res[0] + b*y[r+0]; + y[r+1] = a*res[1] + b*y[r+1]; + } + for (; r < nrow; r++) { - row[0] = m + (r * ncol); + row[0] = m + ((r+0) * ncol); gemv_1xN_kernel4_avx2(nc, row[0], x, res); for (c = nc; c < ncol; c++) { -- cgit v1.2.1