From 3be8d6291c95bada9576963770bb9c988708ecda Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Sat, 9 May 2020 15:01:34 -0700 Subject: fix: indexing bug associated to columns in gemv function --- include/libmath.h | 2 +- sys/libmath/blas.c | 128 ++++++++++++++++++++++++++++++++++------------------- 2 files changed, 83 insertions(+), 47 deletions(-) diff --git a/include/libmath.h b/include/libmath.h index ecce28e..b148065 100644 --- a/include/libmath.h +++ b/include/libmath.h @@ -160,8 +160,8 @@ int blas·argmin(int len, double *x, int inc); /* level 2 */ void blas·tpmv(blas·Flag f, int n, double *m, double *x); -void blas·tpsv(blas·Flag f, int n, double *m, double *x); error blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y) ; +void blas·tpsv(blas·Flag f, int n, double *m, double *x); void blas·ger(int nrow, int ncol, double a, double *x, double *y, double *m); void blas·her(int n, double a, double *x, double *m); void blas·syr(int nrow, int ncol, double a, double *x, double *m); diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c index 1e1e1c7..7056c08 100644 --- a/sys/libmath/blas.c +++ b/sys/libmath/blas.c @@ -1110,7 +1110,7 @@ blas·tpsv(blas·Flag f, int n, double *m, double *x) */ static void -gemv_kernel4xN_4_avx2(int ncol, double **row, double *x, double *y) +gemv_4xN_kernel4_avx2(int ncol, double **row, double *x, double *y) { int c; __m128d hr; @@ -1137,17 +1137,17 @@ gemv_kernel4xN_4_avx2(int ncol, double **row, double *x, double *y) static void -gemv_kernel4xN_4(int ncol, double **row, double *x, double *y) +gemv_4xN_kernel4(int n, double **row, double *x, double *y) { int c; double res[4]; - res[0] = 0.; - res[1] = 0.; - res[2] = 0.; - res[3] = 0.; + res[0] = 0.0; + res[1] = 0.0; + res[2] = 0.0; + res[3] = 0.0; - for (c = 0; c < ncol; c += 4) { + for (c = 0; c < n; c += 4) { res[0] += row[0][c+0]*x[c+0] + row[0][c+1]*x[c+1] + row[0][c+2]*x[c+2] + row[0][c+3]*x[c+3]; res[1] += row[1][c+0]*x[c+0] + row[1][c+1]*x[c+1] + row[1][c+2]*x[c+2] + row[1][c+3]*x[c+3]; res[2] += row[2][c+0]*x[c+0] + row[2][c+1]*x[c+1] + row[2][c+2]*x[c+2] + row[2][c+3]*x[c+3]; @@ -1162,17 +1162,36 @@ gemv_kernel4xN_4(int ncol, double **row, double *x, double *y) static void -gemv_kernel1xN_4(int ncol, double *row, double *x, double *y) +gemv_1xN_kernel4_avx2(int n, double *row, double *x, double *y) +{ + int c; + __m128d r128; + __m256d r256; + + r256 = _mm256_setzero_pd(); + for (c = 0; c < n; c += 4) { + r256 += _mm256_loadu_pd(row+c) * _mm256_loadu_pd(x+c); + } + + r128 = _mm_add_pd(_mm256_extractf128_pd(r256, 0), _mm256_extractf128_pd(r256, 1)); + r128 = _mm_hadd_pd(r128, r128); + + *y = r128[0]; +} + +static +void +gemv_1xN_kernel4(int n, double *row, double *x, double *y) { int c; double res; res = 0.; - for (c = 0; c < ncol; c += 4) { + for (c = 0; c < n; c += 4) { res += row[c+0]*x[c+0] + row[c+1]*x[c+1] + row[c+2]*x[c+2] + row[c+3]*x[c+3]; } - y[0] = res; + *y = res; } error @@ -1181,32 +1200,40 @@ blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double int c, r, nr, nc; double *row[4], res[4]; - nr = nrow & ~3; - nc = ncol & ~3; - for (r = 0; r < nrow; r += 4) { - row[0] = m + (r * (ncol+0)); - row[1] = m + (r * (ncol+1)); - row[2] = m + (r * (ncol+2)); - row[3] = m + (r * (ncol+3)); + nr = EVEN_BY(nrow, 4); + nc = EVEN_BY(ncol, 4); + + for (r = 0; r < nr; r += 4) { + /* assumes row major layout */ + row[0] = m + ((r+0) * ncol); + row[1] = m + ((r+1) * ncol); + row[2] = m + ((r+2) * ncol); + row[3] = m + ((r+3) * ncol); - gemv_kernel4xN_4_avx2(ncol, row, x + r, res); + gemv_4xN_kernel4_avx2(nc, row, x, res); for (c = nc; c < ncol; c++) { - res[0] += row[0][c]; - res[1] += row[1][c]; - res[2] += row[2][c]; - res[3] += row[3][c]; + res[0] += row[0][c]*x[c]; + res[1] += row[1][c]*x[c]; + res[2] += row[2][c]*x[c]; + res[3] += row[3][c]*x[c]; } - y[r+0] = res[0] + b*y[r+0]; - y[r+1] = res[1] + b*y[r+1]; - y[r+2] = res[2] + b*y[r+2]; - y[r+3] = res[3] + b*y[r+3]; + y[r+0] = a*res[0] + b*y[r+0]; + y[r+1] = a*res[1] + b*y[r+1]; + y[r+2] = a*res[2] + b*y[r+2]; + y[r+3] = a*res[3] + b*y[r+3]; } for (; r < nrow; r++) { - gemv_kernel1xN_4(nrow, m + (r * ncol), x + r, res); - y[r] = res[0] + b*y[r]; + row[0] = m + (r * ncol); + gemv_1xN_kernel4_avx2(nc, row[0], x, res); + + for (c = nc; c < ncol; c++) { + res[0] += row[0][c]*x[c]; + } + + y[r] = a*res[0] + b*y[r]; } return 0; @@ -1343,9 +1370,9 @@ blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2) { } -#define NITER 3000 -#define NCOL 5000007 -#define NROW 57 +#define NITER 1000 +#define NCOL 1005 +#define NROW 1005 error test·level3() @@ -1416,40 +1443,46 @@ test·level2() int i, j, n, it; clock_t t; - double *x, *y, *m; + double *x, *y, *z, *m; double tprof[2]; rng·init(0); - tprof[0] = 0, tprof[1] = 0; + tprof[0] = 0; + tprof[1] = 0; + x = malloc(sizeof(*x)*NCOL); y = malloc(sizeof(*x)*NCOL); - m = malloc(sizeof(*x)*NCOL*(NCOL+1)/2); + z = malloc(sizeof(*x)*NCOL); + m = malloc(sizeof(*x)*NROW*NCOL); for (it = 0; it < NITER; it++) { n = 0; - for (i = 0; i < NCOL; i++) { + for (i = 0; i < NROW; i++) { + x[i] = rng·random(); y[i] = rng·random(); - for (j = i; j < NCOL; j++) { - m[n++] = rng·random() + .1; // To ensure not singular + for (j = 0; j < NCOL; j++) { + m[n++] = rng·random() + .1; } } - memcpy(x, y, NCOL * sizeof(*x)); + memcpy(z, y, NCOL * sizeof(*y)); t = clock(); - blas·tpsv(0, NCOL, m, x); + blas·gemv(NROW, NCOL, 2, m, x, 0.0, y); t = clock() - t; + tprof[0] += 1000.*t/CLOCKS_PER_SEC; t = clock(); - cblas_dtpsv(CblasRowMajor, CblasUpper, CblasNoTrans, CblasNonUnit, NCOL, m, y, 1); + cblas_dgemv(CblasRowMajor, CblasNoTrans, NROW, NCOL, 2, m, NROW, x, 1, 0.0, z, 1); t = clock() - t; + tprof[1] += 1000.*t/CLOCKS_PER_SEC; for (i = 0; i < NCOL; i++) { - if (math·abs(x[i] - y[i])/math·abs(x[i]) > 1e-5) { - errorf("failure at index %d: %f != %f", i, x[i], y[i]); + if (math·abs(z[i] - y[i])/math·abs(x[i]) > 1e-5) { + errorf("failure at index %d: %f != %f", i, z[i], y[i]); } } } @@ -1529,7 +1562,7 @@ test·level1() #define STEP 1 error -main() +test·argmax() { int i, n; double *x, *y, *w, *z; @@ -1575,9 +1608,12 @@ main() // } } - printf("%f, %f\n", res[0], res[1]); - printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER); - printf("mean time/iteration (oblas): %fms\n", tprof[1]/NITER); + return 0; +} +error +main() +{ + test·level2(); return 0; } -- cgit v1.2.1