author	Nicholas Noll <nbnoll@eml.cc>	2020-05-09 16:35:04 -0700
committer	Nicholas Noll <nbnoll@eml.cc>	2020-05-09 16:35:04 -0700
commit	eaa498806479e30a6a825afaff63e2a1fe5702f9 (patch)
tree	9eff26f35774cf31443c8477cc4f0e354bf6c506 /sys
parent	3be8d6291c95bada9576963770bb9c988708ecda (diff)
added helper function for horizontal sums
Diffstat (limited to 'sys')
-rw-r--r--	sys/libmath/blas.c	187

1 file changed, 160 insertions(+), 27 deletions(-)
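In outline, the patch factors the repeated AVX2 reduction epilogue into a single helper, hsum_avx2, and reworks blas·gemv (the general affine transformation y = aMx + by over a row-major M) to peel rows in blocks of 8, 4, 2, and 1. For orientation, a plain-C reference of the operation all of the kernels below accelerate; this sketch is illustrative and not part of the patch:

/* Reference semantics of blas·gemv: y <- a*M*x + b*y,
 * with M stored row-major as an nrow x ncol matrix. */
static void
gemv_ref(int nrow, int ncol, double a, double *m, double *x, double b, double *y)
{
	int r, c;
	double acc;

	for (r = 0; r < nrow; r++) {
		acc = 0.;
		for (c = 0; c < ncol; c++)
			acc += m[r*ncol + c] * x[c];
		y[r] = a*acc + b*y[r];
	}
}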
diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c
index 7056c08..18a43a6 100644
--- a/sys/libmath/blas.c
+++ b/sys/libmath/blas.c
@@ -9,6 +9,26 @@
#define EVEN_BY(x, n) (x) & ~((n)-1)
// -----------------------------------------------------------------------
+// misc functions
+
+static
+inline
+double
+hsum_avx2(__m256d x)
+{
+ __m128d lo128, hi128, hi64;
+
+ lo128 = _mm256_castpd256_pd128(x);
+ hi128 = _mm256_extractf128_pd(x, 1);
+ lo128 = _mm_add_pd(lo128, hi128);
+
+ hi64 = _mm_unpackhi_pd(lo128, lo128);
+
+ return _mm_cvtsd_f64(_mm_add_sd(lo128, hi64));
+}
+
+
+// -----------------------------------------------------------------------
// level one
/*
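The new helper reduces the four lanes of a __m256d to a single double: the cast yields the low 128 bits, _mm256_extractf128_pd the high 128 bits, one packed add folds four lanes to two, and _mm_unpackhi_pd plus _mm_add_sd fold two to one. A scalar model of the computation (hsum_model is a hypothetical name, not in the patch):

/* Scalar model of hsum_avx2: for x = {v0, v1, v2, v3},
 * the intrinsic sequence returns (v0+v2) + (v1+v3). */
static double
hsum_model(const double v[4])
{
	double lane0 = v[0] + v[2];	/* _mm_add_pd, lane 0 */
	double lane1 = v[1] + v[3];	/* _mm_add_pd, lane 1 */

	return lane0 + lane1;	/* _mm_add_sd of lane 0 and the unpacked lane 1 */
}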
@@ -482,10 +502,7 @@ dot_kernel8_avx2(int len, double *x, double *y)
sum[0] += sum[1] + sum[2] + sum[3];
- res = _mm_add_pd(_mm256_extractf128_pd(sum[0], 0), _mm256_extractf128_pd(sum[0], 1));
- res = _mm_hadd_pd(res, res);
-
- return res[0];
+ return hsum_avx2(sum[0]);
}
static
@@ -509,10 +526,7 @@ dot_kernel8_fma3(int len, double *x, double *y)
sum[0] += sum[1] + sum[2] + sum[3];
- res = _mm_add_pd(_mm256_extractf128_pd(sum[0], 0), _mm256_extractf128_pd(sum[0], 1));
- res = _mm_hadd_pd(res, res);
-
- return res[0];
+ return hsum_avx2(sum[0]);
}
static
@@ -604,10 +618,7 @@ sum_kernel8_avx2(int len, double *x)
sum[0] += sum[1];
- res = _mm_add_pd(_mm256_extractf128_pd(sum[0], 0), _mm256_extractf128_pd(sum[0], 1));
- res = _mm_hadd_pd(res, res);
-
- return res[0];
+ return hsum_avx2(sum[0]);
}
static
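All three reduction epilogues (dot avx2, dot fma3, sum avx2) previously repeated the same extract/hadd sequence and now share hsum_avx2. Two small wins beyond deduplication, both visible in the hunks: the low half is now obtained with a no-op _mm256_castpd256_pd128 rather than an _mm256_extractf128_pd(x, 0), and _mm_hadd_pd, which decodes to two shuffle µops plus an add on most x86 cores, is replaced by a single _mm_unpackhi_pd. The removed idiom, written out as a standalone function for comparison (hadd_epilogue is an illustrative name):

#include <immintrin.h>

/* The epilogue this commit removes: extract both halves, add them,
 * then a horizontal add. Equivalent to hsum_avx2, one shuffle more. */
static inline double
hadd_epilogue(__m256d s)
{
	__m128d r;

	r = _mm_add_pd(_mm256_extractf128_pd(s, 0), _mm256_extractf128_pd(s, 1));
	r = _mm_hadd_pd(r, r);

	return _mm_cvtsd_f64(r);
}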
@@ -1108,9 +1119,44 @@ blas·tpsv(blas·Flag f, int n, double *m, double *x)
* general affine transformation
* y = aMx + by
*/
+
+static
+void
+gemv_8xN_kernel4_avx2(int ncol, double *row[8], double *x, double *y)
+{
+ int c;
+ __m128d hr;
+ __m256d x256, r256[8];
+
+ for (c = 0; c < 8; c++) {
+ r256[c] = _mm256_setzero_pd();
+ }
+
+ for (c = 0; c < ncol; c += 4) {
+ x256 = _mm256_loadu_pd(x+c);
+ r256[0] += x256 * _mm256_loadu_pd(row[0] + c);
+ r256[1] += x256 * _mm256_loadu_pd(row[1] + c);
+ r256[2] += x256 * _mm256_loadu_pd(row[2] + c);
+ r256[3] += x256 * _mm256_loadu_pd(row[3] + c);
+ r256[4] += x256 * _mm256_loadu_pd(row[4] + c);
+ r256[5] += x256 * _mm256_loadu_pd(row[5] + c);
+ r256[6] += x256 * _mm256_loadu_pd(row[6] + c);
+ r256[7] += x256 * _mm256_loadu_pd(row[7] + c);
+ }
+
+ y[0] = hsum_avx2(r256[0]);
+ y[1] = hsum_avx2(r256[1]);
+ y[2] = hsum_avx2(r256[2]);
+ y[3] = hsum_avx2(r256[3]);
+ y[4] = hsum_avx2(r256[4]);
+ y[5] = hsum_avx2(r256[5]);
+ y[6] = hsum_avx2(r256[6]);
+ y[7] = hsum_avx2(r256[7]);
+}
+
static
void
-gemv_4xN_kernel4_avx2(int ncol, double **row, double *x, double *y)
+gemv_4xN_kernel4_avx2(int ncol, double *row[4], double *x, double *y)
{
int c;
__m128d hr;
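gemv_8xN_kernel4_avx2 accumulates eight row·x dot products at once, reusing each load of x across eight rows, and hands the eight lane-sums to hsum_avx2. A scalar reference of what it computes (gemv_8xN_model is a hypothetical checking aid, not in the patch); the kernel itself strides four columns at a time, so the caller must pass an ncol rounded down via EVEN_BY(ncol, 4):

/* Scalar model: y[i] = dot(row[i], x) over ncol columns, i = 0..7.
 * Assumes ncol % 4 == 0, matching the kernel's stride. */
static void
gemv_8xN_model(int ncol, double *row[8], double *x, double *y)
{
	int i, c;

	for (i = 0; i < 8; i++) {
		y[i] = 0.;
		for (c = 0; c < ncol; c++)
			y[i] += row[i][c] * x[c];
	}
}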
@@ -1128,16 +1174,15 @@ gemv_4xN_kernel4_avx2(int ncol, double **row, double *x, double *y)
r256[3] += x256 * _mm256_loadu_pd(row[3] + c);
}
- for (c = 0; c < 4; c++) {
- hr = _mm_add_pd(_mm256_extractf128_pd(r256[c], 0), _mm256_extractf128_pd(r256[c], 1));
- hr = _mm_hadd_pd(hr, hr);
- y[c] = hr[0];
- }
+ y[0] = hsum_avx2(r256[0]);
+ y[1] = hsum_avx2(r256[1]);
+ y[2] = hsum_avx2(r256[2]);
+ y[3] = hsum_avx2(r256[3]);
}
static
void
-gemv_4xN_kernel4(int n, double **row, double *x, double *y)
+gemv_4xN_kernel4(int n, double *row[4], double *x, double *y)
{
int c;
double res[4];
@@ -1162,6 +1207,47 @@ gemv_4xN_kernel4(int n, double **row, double *x, double *y)
static
void
+gemv_2xN_kernel4_avx2(int n, double *row[2], double *x, double *y)
+{
+ int c;
+ __m128d hr;
+ __m256d x256, r256[2];
+
+ for (c = 0; c < 2; c++) {
+ r256[c] = _mm256_setzero_pd();
+ }
+
+ for (c = 0; c < n; c += 4) {
+ x256 = _mm256_loadu_pd(x+c);
+ r256[0] += x256 * _mm256_loadu_pd(row[0] + c);
+ r256[1] += x256 * _mm256_loadu_pd(row[1] + c);
+ }
+
+ y[0] = hsum_avx2(r256[0]);
+ y[1] = hsum_avx2(r256[1]);
+}
+
+static
+void
+gemv_2xN_kernel4(int n, double *row[2], double *x, double *y)
+{
+ int c;
+ double res[2];
+
+ res[0] = 0.0;
+ res[1] = 0.0;
+
+ for (c = 0; c < n; c += 4) {
+ res[0] += row[0][c+0]*x[c+0] + row[0][c+1]*x[c+1] + row[0][c+2]*x[c+2] + row[0][c+3]*x[c+3];
+ res[1] += row[1][c+0]*x[c+0] + row[1][c+1]*x[c+1] + row[1][c+2]*x[c+2] + row[1][c+3]*x[c+3];
+ }
+
+ y[0] = res[0];
+ y[1] = res[1];
+}
+
+static
+void
gemv_1xN_kernel4_avx2(int n, double *row, double *x, double *y)
{
int c;
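The 2xN pair mirrors the 4xN one: an AVX2 kernel plus a portable plain-C version, both striding four columns per iteration, so each relies on the caller passing n = EVEN_BY(ncol, 4) and summing the leftover columns itself. A minimal agreement check, assuming the two kernels are made visible for testing (as written above they are file-local statics):

/* Both 2xN kernels should agree up to floating-point reassociation. */
#include <stdio.h>

int
main(void)
{
	double r0[8] = {1, 2, 3, 4, 5, 6, 7, 8};
	double r1[8] = {8, 7, 6, 5, 4, 3, 2, 1};
	double x[8]  = {.5, 1, 1.5, 2, 2.5, 3, 3.5, 4};
	double *row[2] = {r0, r1}, ys[2], yv[2];

	gemv_2xN_kernel4(8, row, x, ys);
	gemv_2xN_kernel4_avx2(8, row, x, yv);
	printf("scalar %g %g | avx2 %g %g\n", ys[0], ys[1], yv[0], yv[1]);

	return 0;
}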
@@ -1173,10 +1259,7 @@ gemv_1xN_kernel4_avx2(int n, double *row, double *x, double *y)
r256 += _mm256_loadu_pd(row+c) * _mm256_loadu_pd(x+c);
}
- r128 = _mm_add_pd(_mm256_extractf128_pd(r256, 0), _mm256_extractf128_pd(r256, 1));
- r128 = _mm_hadd_pd(r128, r128);
-
- *y = r128[0];
+ *y = hsum_avx2(r256);
}
static
@@ -1198,12 +1281,47 @@ error
blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y)
{
int c, r, nr, nc;
- double *row[4], res[4];
+ double *row[8], res[8];
- nr = EVEN_BY(nrow, 4);
nc = EVEN_BY(ncol, 4);
- for (r = 0; r < nr; r += 4) {
+ nr = EVEN_BY(nrow, 8);
+ for (r = 0; r < nr; r += 8) {
+ /* assumes row major layout */
+ row[0] = m + ((r+0) * ncol);
+ row[1] = m + ((r+1) * ncol);
+ row[2] = m + ((r+2) * ncol);
+ row[3] = m + ((r+3) * ncol);
+ row[4] = m + ((r+4) * ncol);
+ row[5] = m + ((r+5) * ncol);
+ row[6] = m + ((r+6) * ncol);
+ row[7] = m + ((r+7) * ncol);
+
+ gemv_8xN_kernel4_avx2(nc, row, x, res);
+
+ for (c = nc; c < ncol; c++) {
+ res[0] += row[0][c]*x[c];
+ res[1] += row[1][c]*x[c];
+ res[2] += row[2][c]*x[c];
+ res[3] += row[3][c]*x[c];
+ res[4] += row[4][c]*x[c];
+ res[5] += row[5][c]*x[c];
+ res[6] += row[6][c]*x[c];
+ res[7] += row[7][c]*x[c];
+ }
+
+ y[r+0] = a*res[0] + b*y[r+0];
+ y[r+1] = a*res[1] + b*y[r+1];
+ y[r+2] = a*res[2] + b*y[r+2];
+ y[r+3] = a*res[3] + b*y[r+3];
+ y[r+4] = a*res[4] + b*y[r+4];
+ y[r+5] = a*res[5] + b*y[r+5];
+ y[r+6] = a*res[6] + b*y[r+6];
+ y[r+7] = a*res[7] + b*y[r+7];
+ }
+
+ nr = EVEN_BY(nrow, 4);
+ for (; r < nr; r += 4) {
/* assumes row major layout */
row[0] = m + ((r+0) * ncol);
row[1] = m + ((r+1) * ncol);
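After this change blas·gemv peels rows in decreasing block sizes, recomputing nr before each stage. A condensed sketch of the cascade (EVEN_BY(x, n) is x & ~(n-1): round x down to a multiple of the power-of-two n):

/* Example: nrow = 23
 *   nr = 23 & ~7 = 16  ->  8x kernel handles rows  0..15
 *   nr = 23 & ~3 = 20  ->  4x kernel handles rows 16..19
 *   nr = 23 & ~1 = 22  ->  2x kernel handles rows 20..21
 *   final loop         ->  1x kernel handles row  22     */
nr = EVEN_BY(nrow, 8); for (r = 0; r < nr; r += 8) { /* 8-row kernel */ }
nr = EVEN_BY(nrow, 4); for (; r < nr; r += 4)      { /* 4-row kernel */ }
nr = EVEN_BY(nrow, 2); for (; r < nr; r += 2)      { /* 2-row kernel */ }
for (; r < nrow; r++)                              { /* 1-row kernel */ }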
@@ -1225,8 +1343,23 @@ blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double
y[r+3] = a*res[3] + b*y[r+3];
}
+ nr = EVEN_BY(nrow, 2);
+ for (; r < nr; r += 2) {
+ row[0] = m + ((r+0) * ncol);
+ row[1] = m + ((r+1) * ncol);
+ gemv_2xN_kernel4_avx2(nc, row, x, res);
+
+ for (c = nc; c < ncol; c++) {
+ res[0] += row[0][c]*x[c];
+ res[1] += row[1][c]*x[c];
+ }
+
+ y[r+0] = a*res[0] + b*y[r+0];
+ y[r+1] = a*res[1] + b*y[r+1];
+ }
+
for (; r < nrow; r++) {
- row[0] = m + (r * ncol);
+ row[0] = m + ((r+0) * ncol);
gemv_1xN_kernel4_avx2(nc, row[0], x, res);
for (c = nc; c < ncol; c++) {