about summary refs log tree commit diff
diff options
context:
space:
mode:
author Nicholas Noll <nbnoll@eml.cc> 2020-05-10 14:22:50 -0700
committer Nicholas Noll <nbnoll@eml.cc> 2020-05-10 14:22:50 -0700
commit 65fe4a1ddd852c9c702ae008c3b880a20b84d8e9 (patch)
tree 1363fee5e6af6a3902925981df4c02b0f85134cc
parent eaa498806479e30a6a825afaff63e2a1fe5702f9 (diff)
begun work on making level 2 strided
-rw-r--r-- include/libmath.h 2
-rw-r--r-- sys/libmath/blas.c 187
2 files changed, 99 insertions, 90 deletions
diff --git a/include/libmath.h b/include/libmath.h
index b148065..5a7dc4e 100644
--- a/include/libmath.h
+++ b/include/libmath.h
@@ -160,7 +160,7 @@ int blas·argmin(int len, double *x, int inc);
/* level 2 */
void blas·tpmv(blas·Flag f, int n, double *m, double *x);
-error blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y) ;
+error blas·gemv(int nrow, int ncol, double a, double *m, int incm, double *x, int incx, double b, double *y, int incy) ;
void blas·tpsv(blas·Flag f, int n, double *m, double *x);
void blas·ger(int nrow, int ncol, double a, double *x, double *y, double *m);
void blas·her(int n, double a, double *x, double *m);
diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c
index 18a43a6..63f8856 100644
--- a/sys/libmath/blas.c
+++ b/sys/libmath/blas.c
@@ -483,7 +483,7 @@ blas·axpy(int len, double a, double *x, int incx, double *y, int incy)
static
double
-dot_kernel8_avx2(int len, double *x, double *y)
+dot_kernel8_fma3(int len, double *x, double *y)
{
register int i;
__m256d sum[4];
@@ -494,10 +494,10 @@ dot_kernel8_avx2(int len, double *x, double *y)
}
for (i = 0; i < len; i += 16) {
- sum[0] += _mm256_loadu_pd(x+i+0) * _mm256_loadu_pd(y+i+0);
- sum[1] += _mm256_loadu_pd(x+i+4) * _mm256_loadu_pd(y+i+4);
- sum[2] += _mm256_loadu_pd(x+i+8) * _mm256_loadu_pd(y+i+8);
- sum[3] += _mm256_loadu_pd(x+i+12) * _mm256_loadu_pd(y+i+12);
+ sum[0] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+0), _mm256_loadu_pd(y+i+0), sum[0]);
+ sum[1] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+4), _mm256_loadu_pd(y+i+4), sum[1]);
+ sum[2] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+8), _mm256_loadu_pd(y+i+8), sum[2]);
+ sum[3] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+12), _mm256_loadu_pd(y+i+12), sum[3]);
}
sum[0] += sum[1] + sum[2] + sum[3];
@@ -507,7 +507,7 @@ dot_kernel8_avx2(int len, double *x, double *y)
static
double
-dot_kernel8_fma3(int len, double *x, double *y)
+dot_kernel8_avx2(int len, double *x, double *y)
{
register int i;
__m256d sum[4];
@@ -518,10 +518,10 @@ dot_kernel8_fma3(int len, double *x, double *y)
}
for (i = 0; i < len; i += 16) {
- sum[0] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+0), _mm256_loadu_pd(y+i+0), sum[0]);
- sum[1] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+4), _mm256_loadu_pd(y+i+4), sum[1]);
- sum[2] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+8), _mm256_loadu_pd(y+i+8), sum[2]);
- sum[3] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+12), _mm256_loadu_pd(y+i+12), sum[3]);
+ sum[0] += _mm256_loadu_pd(x+i+0) * _mm256_loadu_pd(y+i+0);
+ sum[1] += _mm256_loadu_pd(x+i+4) * _mm256_loadu_pd(y+i+4);
+ sum[2] += _mm256_loadu_pd(x+i+8) * _mm256_loadu_pd(y+i+8);
+ sum[3] += _mm256_loadu_pd(x+i+12) * _mm256_loadu_pd(y+i+12);
}
sum[0] += sum[1] + sum[2] + sum[3];
@@ -561,6 +561,7 @@ blas·dot(int len, double *x, int incx, double *y, int incy)
if (incx == 1 && incy == 1) {
n = EVEN_BY(len, 16);
sum[0] = dot_kernel8_fma3(n, x, y);
+
x += n;
y += n;
} else {
@@ -576,7 +577,7 @@ blas·dot(int len, double *x, int incx, double *y, int incy)
}
}
- for (; i < len; i++, x += incx, y += incy) {
+ for (; n < len; n++, x += incx, y += incy) {
sum[0] += x[0] * y[0];
}
@@ -1259,6 +1260,10 @@ gemv_1xN_kernel4_avx2(int n, double *row, double *x, double *y)
r256 += _mm256_loadu_pd(row+c) * _mm256_loadu_pd(x+c);
}
+ r128 = _mm_add_pd(_mm256_extractf128_pd(r256, 0), _mm256_extractf128_pd(r256, 1));
+ r128 = _mm_hadd_pd(r128, r128);
+
+ *y = r128[0];
*y = hsum_avx2(r256);
}
@@ -1278,95 +1283,99 @@ gemv_1xN_kernel4(int n, double *row, double *x, double *y)
}
error
-blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y)
+blas·gemv(int nrow, int ncol, double a, double *m, int incm, double *x, int incx, double b, double *y, int incy)
{
int c, r, nr, nc;
double *row[8], res[8];
+ enum {
+ err·nil,
+ err·incm,
+ };
- nc = EVEN_BY(ncol, 4);
-
- nr = EVEN_BY(nrow, 8);
- for (r = 0; r < nr; r += 8) {
- /* assumes row major layout */
- row[0] = m + ((r+0) * ncol);
- row[1] = m + ((r+1) * ncol);
- row[2] = m + ((r+2) * ncol);
- row[3] = m + ((r+3) * ncol);
- row[4] = m + ((r+4) * ncol);
- row[5] = m + ((r+5) * ncol);
- row[6] = m + ((r+6) * ncol);
- row[7] = m + ((r+7) * ncol);
-
- gemv_8xN_kernel4_avx2(nc, row, x, res);
-
- for (c = nc; c < ncol; c++) {
- res[0] += row[0][c]*x[c];
- res[1] += row[1][c]*x[c];
- res[2] += row[2][c]*x[c];
- res[3] += row[3][c]*x[c];
- res[4] += row[4][c]*x[c];
- res[5] += row[5][c]*x[c];
- res[6] += row[6][c]*x[c];
- res[7] += row[7][c]*x[c];
- }
+ if (incm < ncol) {
+ errorf("aliased matrix: inc = %d < ncols = %d", incm, ncol);
+ return err·incm;
+ }
- y[r+0] = a*res[0] + b*y[r+0];
- y[r+1] = a*res[1] + b*y[r+1];
- y[r+2] = a*res[2] + b*y[r+2];
- y[r+3] = a*res[3] + b*y[r+3];
- y[r+4] = a*res[4] + b*y[r+4];
- y[r+5] = a*res[5] + b*y[r+5];
- y[r+6] = a*res[6] + b*y[r+6];
- y[r+7] = a*res[7] + b*y[r+7];
- }
-
- nr = EVEN_BY(nrow, 4);
- for (; r < nr; r += 4) {
- /* assumes row major layout */
- row[0] = m + ((r+0) * ncol);
- row[1] = m + ((r+1) * ncol);
- row[2] = m + ((r+2) * ncol);
- row[3] = m + ((r+3) * ncol);
-
- gemv_4xN_kernel4_avx2(nc, row, x, res);
-
- for (c = nc; c < ncol; c++) {
- res[0] += row[0][c]*x[c];
- res[1] += row[1][c]*x[c];
- res[2] += row[2][c]*x[c];
- res[3] += row[3][c]*x[c];
+ if (incx == 1 && incy == 1) {
+ nc = EVEN_BY(ncol, 4);
+
+ nr = EVEN_BY(nrow, 8);
+ for (r = 0; r < nr; r += 8) {
+ row[0] = m + ((r+0) * incm);
+ row[1] = m + ((r+1) * incm);
+ row[2] = m + ((r+2) * incm);
+ row[3] = m + ((r+3) * incm);
+ row[4] = m + ((r+4) * incm);
+ row[5] = m + ((r+5) * incm);
+ row[6] = m + ((r+6) * incm);
+ row[7] = m + ((r+7) * incm);
+
+ gemv_8xN_kernel4_avx2(nc, row, x, res);
+
+ for (c = nc; c < ncol; c++) {
+ res[0] += row[0][c]*x[c];
+ res[1] += row[1][c]*x[c];
+ res[2] += row[2][c]*x[c];
+ res[3] += row[3][c]*x[c];
+ res[4] += row[4][c]*x[c];
+ res[5] += row[5][c]*x[c];
+ res[6] += row[6][c]*x[c];
+ res[7] += row[7][c]*x[c];
+ }
+
+ y[r+0] = a*res[0] + b*y[r+0];
+ y[r+1] = a*res[1] + b*y[r+1];
+ y[r+2] = a*res[2] + b*y[r+2];
+ y[r+3] = a*res[3] + b*y[r+3];
+ y[r+4] = a*res[4] + b*y[r+4];
+ y[r+5] = a*res[5] + b*y[r+5];
+ y[r+6] = a*res[6] + b*y[r+6];
+ y[r+7] = a*res[7] + b*y[r+7];
}
- y[r+0] = a*res[0] + b*y[r+0];
- y[r+1] = a*res[1] + b*y[r+1];
- y[r+2] = a*res[2] + b*y[r+2];
- y[r+3] = a*res[3] + b*y[r+3];
- }
+ nr = EVEN_BY(nrow, 4);
+ for (; r < nr; r += 4) {
+ row[0] = m + ((r+0) * incm);
+ row[1] = m + ((r+1) * incm);
+ row[2] = m + ((r+2) * incm);
+ row[3] = m + ((r+3) * incm);
+
+ gemv_4xN_kernel4_avx2(nc, row, x, res);
- nr = EVEN_BY(nrow, 2);
- for (; r < nr; r += 2) {
- row[0] = m + ((r+0) * ncol);
- row[1] = m + ((r+1) * ncol);
- gemv_2xN_kernel4_avx2(nc, row, x, res);
+ for (c = nc; c < ncol; c++) {
+ res[0] += row[0][c]*x[c];
+ res[1] += row[1][c]*x[c];
+ res[2] += row[2][c]*x[c];
+ res[3] += row[3][c]*x[c];
+ }
- for (c = nc; c < ncol; c++) {
- res[0] += row[0][c]*x[c];
- res[1] += row[1][c]*x[c];
+ y[r+0] = a*res[0] + b*y[r+0];
+ y[r+1] = a*res[1] + b*y[r+1];
+ y[r+2] = a*res[2] + b*y[r+2];
+ y[r+3] = a*res[3] + b*y[r+3];
}
- y[r+0] = a*res[0] + b*y[r+0];
- y[r+1] = a*res[1] + b*y[r+1];
- }
+ nr = EVEN_BY(nrow, 2);
+ for (; r < nr; r += 2) {
+ row[0] = m + ((r+0) * incm);
+ row[1] = m + ((r+1) * incm);
+ gemv_2xN_kernel4_avx2(nc, row, x, res);
- for (; r < nrow; r++) {
- row[0] = m + ((r+0) * ncol);
- gemv_1xN_kernel4_avx2(nc, row[0], x, res);
+ for (c = nc; c < ncol; c++) {
+ res[0] += row[0][c]*x[c];
+ res[1] += row[1][c]*x[c];
+ }
- for (c = nc; c < ncol; c++) {
- res[0] += row[0][c]*x[c];
+ y[r+0] = a*res[0] + b*y[r+0];
+ y[r+1] = a*res[1] + b*y[r+1];
}
- y[r] = a*res[0] + b*y[r];
+ for (; r < nrow; r++) {
+ row[0] = m + ((r+0) * ncol);
+ res[0] = blas·dot(ncol, row[0], 1, x, 1);
+ y[r] = a*res[0] + b*y[r];
+ }
}
return 0;
@@ -1504,8 +1513,8 @@ blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2)
}
#define NITER 1000
-#define NCOL 1005
-#define NROW 1005
+#define NCOL 2005
+#define NROW 2005
error
test·level3()
@@ -1602,13 +1611,13 @@ test·level2()
memcpy(z, y, NCOL * sizeof(*y));
t = clock();
- blas·gemv(NROW, NCOL, 2, m, x, 0.0, y);
+ blas·gemv(NROW, NCOL, 2, m, NCOL, x, 1, 1.0, y, 1);
t = clock() - t;
tprof[0] += 1000.*t/CLOCKS_PER_SEC;
t = clock();
- cblas_dgemv(CblasRowMajor, CblasNoTrans, NROW, NCOL, 2, m, NROW, x, 1, 0.0, z, 1);
+ cblas_dgemv(CblasRowMajor, CblasNoTrans, NROW, NCOL, 2, m, NROW, x, 1, 1.0, z, 1);
t = clock() - t;
tprof[1] += 1000.*t/CLOCKS_PER_SEC;