From 65fe4a1ddd852c9c702ae008c3b880a20b84d8e9 Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Sun, 10 May 2020 14:22:50 -0700 Subject: begun work on making level 2 strided --- sys/libmath/blas.c | 187 ++++++++++++++++++++++++++++------------------------- 1 file changed, 98 insertions(+), 89 deletions(-) (limited to 'sys') diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c index 18a43a6..63f8856 100644 --- a/sys/libmath/blas.c +++ b/sys/libmath/blas.c @@ -483,7 +483,7 @@ blas·axpy(int len, double a, double *x, int incx, double *y, int incy) static double -dot_kernel8_avx2(int len, double *x, double *y) +dot_kernel8_fma3(int len, double *x, double *y) { register int i; __m256d sum[4]; @@ -494,10 +494,10 @@ dot_kernel8_avx2(int len, double *x, double *y) } for (i = 0; i < len; i += 16) { - sum[0] += _mm256_loadu_pd(x+i+0) * _mm256_loadu_pd(y+i+0); - sum[1] += _mm256_loadu_pd(x+i+4) * _mm256_loadu_pd(y+i+4); - sum[2] += _mm256_loadu_pd(x+i+8) * _mm256_loadu_pd(y+i+8); - sum[3] += _mm256_loadu_pd(x+i+12) * _mm256_loadu_pd(y+i+12); + sum[0] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+0), _mm256_loadu_pd(y+i+0), sum[0]); + sum[1] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+4), _mm256_loadu_pd(y+i+4), sum[1]); + sum[2] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+8), _mm256_loadu_pd(y+i+8), sum[2]); + sum[3] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+12), _mm256_loadu_pd(y+i+12), sum[3]); } sum[0] += sum[1] + sum[2] + sum[3]; @@ -507,7 +507,7 @@ dot_kernel8_avx2(int len, double *x, double *y) static double -dot_kernel8_fma3(int len, double *x, double *y) +dot_kernel8_avx2(int len, double *x, double *y) { register int i; __m256d sum[4]; @@ -518,10 +518,10 @@ dot_kernel8_fma3(int len, double *x, double *y) } for (i = 0; i < len; i += 16) { - sum[0] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+0), _mm256_loadu_pd(y+i+0), sum[0]); - sum[1] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+4), _mm256_loadu_pd(y+i+4), sum[1]); - sum[2] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+8), _mm256_loadu_pd(y+i+8), sum[2]); - sum[3] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+12), _mm256_loadu_pd(y+i+12), sum[3]); + sum[0] += _mm256_loadu_pd(x+i+0) * _mm256_loadu_pd(y+i+0); + sum[1] += _mm256_loadu_pd(x+i+4) * _mm256_loadu_pd(y+i+4); + sum[2] += _mm256_loadu_pd(x+i+8) * _mm256_loadu_pd(y+i+8); + sum[3] += _mm256_loadu_pd(x+i+12) * _mm256_loadu_pd(y+i+12); } sum[0] += sum[1] + sum[2] + sum[3]; @@ -561,6 +561,7 @@ blas·dot(int len, double *x, int incx, double *y, int incy) if (incx == 1 && incy == 1) { n = EVEN_BY(len, 16); sum[0] = dot_kernel8_fma3(n, x, y); + x += n; y += n; } else { @@ -576,7 +577,7 @@ blas·dot(int len, double *x, int incx, double *y, int incy) } } - for (; i < len; i++, x += incx, y += incy) { + for (; n < len; n++, x += incx, y += incy) { sum[0] += x[0] * y[0]; } @@ -1259,6 +1260,10 @@ gemv_1xN_kernel4_avx2(int n, double *row, double *x, double *y) r256 += _mm256_loadu_pd(row+c) * _mm256_loadu_pd(x+c); } + r128 = _mm_add_pd(_mm256_extractf128_pd(r256, 0), _mm256_extractf128_pd(r256, 1)); + r128 = _mm_hadd_pd(r128, r128); + + *y = r128[0]; *y = hsum_avx2(r256); } @@ -1278,95 +1283,99 @@ gemv_1xN_kernel4(int n, double *row, double *x, double *y) } error -blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y) +blas·gemv(int nrow, int ncol, double a, double *m, int incm, double *x, int incx, double b, double *y, int incy) { int c, r, nr, nc; double *row[8], res[8]; + enum { + err·nil, + err·incm, + }; - nc = EVEN_BY(ncol, 4); - - nr = EVEN_BY(nrow, 8); - for (r = 0; r < nr; r += 8) { - /* assumes row major layout */ - row[0] = m + ((r+0) * ncol); - row[1] = m + ((r+1) * ncol); - row[2] = m + ((r+2) * ncol); - row[3] = m + ((r+3) * ncol); - row[4] = m + ((r+4) * ncol); - row[5] = m + ((r+5) * ncol); - row[6] = m + ((r+6) * ncol); - row[7] = m + ((r+7) * ncol); - - gemv_8xN_kernel4_avx2(nc, row, x, res); - - for (c = nc; c < ncol; c++) { - res[0] += row[0][c]*x[c]; - res[1] += row[1][c]*x[c]; - res[2] += row[2][c]*x[c]; - res[3] += row[3][c]*x[c]; - res[4] += row[4][c]*x[c]; - res[5] += row[5][c]*x[c]; - res[6] += row[6][c]*x[c]; - res[7] += row[7][c]*x[c]; - } + if (incm < ncol) { + errorf("aliased matrix: inc = %d < ncols = %d", incm, ncol); + return err·incm; + } - y[r+0] = a*res[0] + b*y[r+0]; - y[r+1] = a*res[1] + b*y[r+1]; - y[r+2] = a*res[2] + b*y[r+2]; - y[r+3] = a*res[3] + b*y[r+3]; - y[r+4] = a*res[4] + b*y[r+4]; - y[r+5] = a*res[5] + b*y[r+5]; - y[r+6] = a*res[6] + b*y[r+6]; - y[r+7] = a*res[7] + b*y[r+7]; - } - - nr = EVEN_BY(nrow, 4); - for (; r < nr; r += 4) { - /* assumes row major layout */ - row[0] = m + ((r+0) * ncol); - row[1] = m + ((r+1) * ncol); - row[2] = m + ((r+2) * ncol); - row[3] = m + ((r+3) * ncol); - - gemv_4xN_kernel4_avx2(nc, row, x, res); - - for (c = nc; c < ncol; c++) { - res[0] += row[0][c]*x[c]; - res[1] += row[1][c]*x[c]; - res[2] += row[2][c]*x[c]; - res[3] += row[3][c]*x[c]; + if (incx == 1 && incy == 1) { + nc = EVEN_BY(ncol, 4); + + nr = EVEN_BY(nrow, 8); + for (r = 0; r < nr; r += 8) { + row[0] = m + ((r+0) * incm); + row[1] = m + ((r+1) * incm); + row[2] = m + ((r+2) * incm); + row[3] = m + ((r+3) * incm); + row[4] = m + ((r+4) * incm); + row[5] = m + ((r+5) * incm); + row[6] = m + ((r+6) * incm); + row[7] = m + ((r+7) * incm); + + gemv_8xN_kernel4_avx2(nc, row, x, res); + + for (c = nc; c < ncol; c++) { + res[0] += row[0][c]*x[c]; + res[1] += row[1][c]*x[c]; + res[2] += row[2][c]*x[c]; + res[3] += row[3][c]*x[c]; + res[4] += row[4][c]*x[c]; + res[5] += row[5][c]*x[c]; + res[6] += row[6][c]*x[c]; + res[7] += row[7][c]*x[c]; + } + + y[r+0] = a*res[0] + b*y[r+0]; + y[r+1] = a*res[1] + b*y[r+1]; + y[r+2] = a*res[2] + b*y[r+2]; + y[r+3] = a*res[3] + b*y[r+3]; + y[r+4] = a*res[4] + b*y[r+4]; + y[r+5] = a*res[5] + b*y[r+5]; + y[r+6] = a*res[6] + b*y[r+6]; + y[r+7] = a*res[7] + b*y[r+7]; } - y[r+0] = a*res[0] + b*y[r+0]; - y[r+1] = a*res[1] + b*y[r+1]; - y[r+2] = a*res[2] + b*y[r+2]; - y[r+3] = a*res[3] + b*y[r+3]; - } + nr = EVEN_BY(nrow, 4); + for (; r < nr; r += 4) { + row[0] = m + ((r+0) * incm); + row[1] = m + ((r+1) * incm); + row[2] = m + ((r+2) * incm); + row[3] = m + ((r+3) * incm); + + gemv_4xN_kernel4_avx2(nc, row, x, res); - nr = EVEN_BY(nrow, 2); - for (; r < nr; r += 2) { - row[0] = m + ((r+0) * ncol); - row[1] = m + ((r+1) * ncol); - gemv_2xN_kernel4_avx2(nc, row, x, res); + for (c = nc; c < ncol; c++) { + res[0] += row[0][c]*x[c]; + res[1] += row[1][c]*x[c]; + res[2] += row[2][c]*x[c]; + res[3] += row[3][c]*x[c]; + } - for (c = nc; c < ncol; c++) { - res[0] += row[0][c]*x[c]; - res[1] += row[1][c]*x[c]; + y[r+0] = a*res[0] + b*y[r+0]; + y[r+1] = a*res[1] + b*y[r+1]; + y[r+2] = a*res[2] + b*y[r+2]; + y[r+3] = a*res[3] + b*y[r+3]; } - y[r+0] = a*res[0] + b*y[r+0]; - y[r+1] = a*res[1] + b*y[r+1]; - } + nr = EVEN_BY(nrow, 2); + for (; r < nr; r += 2) { + row[0] = m + ((r+0) * incm); + row[1] = m + ((r+1) * incm); + gemv_2xN_kernel4_avx2(nc, row, x, res); - for (; r < nrow; r++) { - row[0] = m + ((r+0) * ncol); - gemv_1xN_kernel4_avx2(nc, row[0], x, res); + for (c = nc; c < ncol; c++) { + res[0] += row[0][c]*x[c]; + res[1] += row[1][c]*x[c]; + } - for (c = nc; c < ncol; c++) { - res[0] += row[0][c]*x[c]; + y[r+0] = a*res[0] + b*y[r+0]; + y[r+1] = a*res[1] + b*y[r+1]; } - y[r] = a*res[0] + b*y[r]; + for (; r < nrow; r++) { + row[0] = m + ((r+0) * ncol); + res[0] = blas·dot(ncol, row[0], 1, x, 1); + y[r] = a*res[0] + b*y[r]; + } } return 0; @@ -1504,8 +1513,8 @@ blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2) } #define NITER 1000 -#define NCOL 1005 -#define NROW 1005 +#define NCOL 2005 +#define NROW 2005 error test·level3() @@ -1602,13 +1611,13 @@ test·level2() memcpy(z, y, NCOL * sizeof(*y)); t = clock(); - blas·gemv(NROW, NCOL, 2, m, x, 0.0, y); + blas·gemv(NROW, NCOL, 2, m, NCOL, x, 1, 1.0, y, 1); t = clock() - t; tprof[0] += 1000.*t/CLOCKS_PER_SEC; t = clock(); - cblas_dgemv(CblasRowMajor, CblasNoTrans, NROW, NCOL, 2, m, NROW, x, 1, 0.0, z, 1); + cblas_dgemv(CblasRowMajor, CblasNoTrans, NROW, NCOL, 2, m, NROW, x, 1, 1.0, z, 1); t = clock() - t; tprof[1] += 1000.*t/CLOCKS_PER_SEC; -- cgit v1.2.1