aboutsummaryrefslogtreecommitdiff
path: root/sys
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-04-30 13:00:53 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-04-30 13:00:53 -0700
commitf9873bfabc066f05ece6510d5c016f5b960d255a (patch)
tree3071a1ed8a86bf2ddda45d097eddaecf8a94c0b3 /sys
parent3c32ffdc8e3552aa58bbb8cdf7757ae808ec7eb6 (diff)
chore: broke out blas-like interface into its own file
Diffstat (limited to 'sys')
-rw-r--r--sys/libmath/blas.c380
-rw-r--r--sys/libmath/rules.mk2
-rw-r--r--sys/libmath/test.c46
3 files changed, 418 insertions, 10 deletions
diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c
new file mode 100644
index 0000000..a008ffb
--- /dev/null
+++ b/sys/libmath/blas.c
@@ -0,0 +1,380 @@
+#include <u.h>
+#include <libn.h>
+#include <vendor/cblas.h>
+
+#include <x86intrin.h>
+
+#include <time.h>
+
+// -----------------------------------------------------------------------
+// Level One
+
+/*
+ * Scale vector
+ * x = ax
+ */
+
/*
 * AVX2 kernel for in-place vector scaling: x[i] *= a for i in [0, n).
 * Processes 8 doubles per iteration as two 256-bit (4-double) lanes;
 * caller must pass n as a multiple of 8 (blas·scalevec masks with ~7).
 * Unaligned loads/stores are used, so x need not be 32-byte aligned.
 * NOTE: `a256 * vec` relies on the GCC/Clang vector-extension operator,
 * not a portable intrinsic.
 */
static
void
scale_kernel8_avx2(int n, double *x, double a)
{
	__m128d a128;
	__m256d a256;
	register int i;

	/* broadcast the scalar a into all four lanes of a 256-bit register */
	a128 = _mm_load_sd(&a);
	a256 = _mm256_broadcastsd_pd(a128);
	for (i = 0; i < n; i += 8) {
		_mm256_storeu_pd(x+i+0, a256 * _mm256_loadu_pd(x+i+0));
		_mm256_storeu_pd(x+i+4, a256 * _mm256_loadu_pd(x+i+4));
	}
}
+
/*
 * Portable fallback kernel for in-place scaling: x[i] *= a.
 * Works in chunks of 8; n must be a multiple of 8.
 */
static
void
scale_kernel8(int n, double *x, double a)
{
	register int base, off;

	for (base = 0; base < n; base += 8) {
		for (off = 0; off < 8; off++) {
			x[base + off] *= a;
		}
	}
}
+
/*
 * Scale a vector in place: x = a*x.
 * The bulk (largest multiple of 8 <= len) runs through the AVX2
 * kernel; any remaining tail elements are scaled one at a time.
 */
void
blas·scalevec(int len, double *x, double a)
{
	int done;

	done = len & ~7;	/* round down to a multiple of 8 */
	scale_kernel8_avx2(done, x, a);

	while (done < len) {
		x[done] *= a;
		done++;
	}
}
+
+/************************************************
+ * Daxpy
+ * y = ax + y
+ ***********************************************/
+
/*
 * AVX2 kernel for daxpy: y[i] += a*x[i] for i in [0, n).
 * Processes 8 doubles per iteration as two 256-bit lanes; caller must
 * pass n as a multiple of 8 (blas·daxpy masks with ~7).
 * Unaligned loads/stores are used throughout.
 * NOTE: `+` and `*` on __m256d values are GCC/Clang vector extensions.
 */
static
void
daxpy_kernel8_avx2(int n, double a, double *x, double *y)
{
	__m128d a128;
	__m256d a256;
	register int i;

	/* broadcast the scalar a into all four lanes */
	a128 = _mm_load_sd(&a);
	a256 = _mm256_broadcastsd_pd(a128);
	for (i = 0; i < n; i += 8) {
		_mm256_storeu_pd(y+i+0, _mm256_loadu_pd(y+i+0) + a256 * _mm256_loadu_pd(x+i+0));
		_mm256_storeu_pd(y+i+4, _mm256_loadu_pd(y+i+4) + a256 * _mm256_loadu_pd(x+i+4));
	}
}
+
/*
 * Portable fallback kernel for daxpy: y[i] += a*x[i].
 * Works in chunks of 8; n must be a multiple of 8.
 */
static
void
daxpy_kernel8(int n, double a, double *x, double *y)
{
	register int base, off;

	for (base = 0; base < n; base += 8) {
		for (off = 0; off < 8; off++) {
			y[base + off] += a * x[base + off];
		}
	}
}
+
/*
 * daxpy: y = a*x + y.
 * The bulk (largest multiple of 8 <= len) runs through the AVX2
 * kernel; remaining tail elements are handled scalar-wise.
 */
void
blas·daxpy(int len, double a, double *x, double *y)
{
	int done;

	done = len & ~7;	/* round down to a multiple of 8 */
	daxpy_kernel8_avx2(done, a, x, y);

	while (done < len) {
		y[done] += a * x[done];
		done++;
	}
}
+
+/************************************************
+ * Dot product
+ * x·y
+ ***********************************************/
+
/*
 * AVX2 dot-product kernel: returns sum of x[i]*y[i] for i in [0, len).
 * Despite the _kernel8 name, the loop consumes 16 doubles per
 * iteration (four 256-bit lanes), so len must be a multiple of 16
 * (blas·dot masks with ~15).  Four independent accumulators hide
 * FP-add latency; they are combined and horizontally reduced at the end.
 * NOTE: `+`/`*` on __m256d and indexing res[0] are GCC/Clang vector
 * extensions; arrlen is presumably a libn element-count macro — verify.
 */
static
double
dot_kernel8_avx2(int len, double *x, double *y)
{
	register int i;
	__m256d sum[4];
	__m128d res;

	for (i = 0; i < arrlen(sum); i++) {
		sum[i] = _mm256_setzero_pd();
	}

	for (i = 0; i < len; i += 16) {
		sum[0] += _mm256_loadu_pd(x+i+0) * _mm256_loadu_pd(y+i+0);
		sum[1] += _mm256_loadu_pd(x+i+4) * _mm256_loadu_pd(y+i+4);
		sum[2] += _mm256_loadu_pd(x+i+8) * _mm256_loadu_pd(y+i+8);
		sum[3] += _mm256_loadu_pd(x+i+12) * _mm256_loadu_pd(y+i+12);
	}

	/* collapse the four partial vectors into one */
	sum[0] += sum[1] + sum[2] + sum[3];

	/* horizontal reduction: high 128 + low 128, then lane + lane */
	res = _mm_add_pd(_mm256_extractf128_pd(sum[0], 0), _mm256_extractf128_pd(sum[0], 1));
	res = _mm_hadd_pd(res, res);

	return res[0];
}
+
/*
 * FMA3 dot-product kernel: returns sum of x[i]*y[i] for i in [0, len).
 * Same structure as dot_kernel8_avx2 but uses fused multiply-add,
 * which halves instruction count and avoids an intermediate rounding.
 * Consumes 16 doubles per iteration, so len must be a multiple of 16
 * (blas·dot masks with ~15).
 */
static
double
dot_kernel8_fma3(int len, double *x, double *y)
{
	register int i;
	__m256d sum[4];
	__m128d res;

	for (i = 0; i < arrlen(sum); i++) {
		sum[i] = _mm256_setzero_pd();
	}

	for (i = 0; i < len; i += 16) {
		sum[0] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+0), _mm256_loadu_pd(y+i+0), sum[0]);
		sum[1] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+4), _mm256_loadu_pd(y+i+4), sum[1]);
		sum[2] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+8), _mm256_loadu_pd(y+i+8), sum[2]);
		sum[3] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+12), _mm256_loadu_pd(y+i+12), sum[3]);
	}

	/* collapse the four partial vectors into one */
	sum[0] += sum[1] + sum[2] + sum[3];

	/* horizontal reduction: high 128 + low 128, then lane + lane */
	res = _mm_add_pd(_mm256_extractf128_pd(sum[0], 0), _mm256_extractf128_pd(sum[0], 1));
	res = _mm_hadd_pd(res, res);

	return res[0];
}
+
/*
 * Portable dot-product kernel: returns sum of x[i]*y[i] for i in [0, len).
 * Unrolled by 8; len must be a multiple of 8.
 * FIX: `res` was read (`res += ...`) before ever being initialized,
 * which is undefined behavior and returned garbage; it now starts at 0.
 */
static
double
dot_kernel8(int len, double *x, double *y)
{
	double res;
	register int i;

	res = 0.;
	for (i = 0; i < len; i += 8) {
		res += x[i] * y[i] +
		       x[i+1] * y[i+1] +
		       x[i+2] * y[i+2] +
		       x[i+3] * y[i+3] +
		       x[i+4] * y[i+4] +
		       x[i+5] * y[i+5] +
		       x[i+6] * y[i+6] +
		       x[i+7] * y[i+7];
	}

	return res;
}
+
/*
 * Dot product: returns x·y.
 * The bulk (largest multiple of 16 <= len) goes through the FMA3
 * kernel, which eats 16 doubles per step; tail elements are
 * accumulated scalar-wise.
 */
double
blas·dot(int len, double *x, double *y)
{
	int i, n;
	double acc;

	n = len & ~15;	/* round down to a multiple of 16 */
	acc = dot_kernel8_fma3(n, x, y);

	i = n;
	while (i < len) {
		acc += x[i] * y[i];
		i++;
	}

	return acc;
}
+
+// -----------------------------------------------------------------------
+// Level Two
+
+/*
+ * Notation: (number of rows) x (number of columns) _ unroll factor
+ * N => variable we sum over
+ */
+
+
+/*
+ * Affine transformation
+ * y = aMx + by
+ */
/*
 * AVX2 gemv kernel: computes y[r] = row[r]·x for four matrix rows
 * r = 0..3 at once.  ncol must be a multiple of 4; the x vector is
 * loaded once per column block and shared across all four rows.
 * Each row keeps its own 256-bit accumulator, reduced horizontally
 * at the end.  Unaligned loads are used throughout.
 * NOTE: `+`/`*` on __m256d and hr[0] indexing are GCC/Clang vector
 * extensions.
 */
static
void
gemv_kernel4xN_4_avx2(int ncol, double **row, double *x, double *y)
{
	int c;
	__m128d hr;
	__m256d x256, r256[4];

	for (c = 0; c < 4; c++) {
		r256[c] = _mm256_setzero_pd();
	}

	for (c = 0; c < ncol; c += 4) {
		x256 = _mm256_loadu_pd(x+c);
		r256[0] += x256 * _mm256_loadu_pd(row[0] + c);
		r256[1] += x256 * _mm256_loadu_pd(row[1] + c);
		r256[2] += x256 * _mm256_loadu_pd(row[2] + c);
		r256[3] += x256 * _mm256_loadu_pd(row[3] + c);
	}

	/* horizontal reduction of each row accumulator into y[c] */
	for (c = 0; c < 4; c++) {
		hr = _mm_add_pd(_mm256_extractf128_pd(r256[c], 0), _mm256_extractf128_pd(r256[c], 1));
		hr = _mm_hadd_pd(hr, hr);
		y[c] = hr[0];
	}
}
+
/*
 * Portable gemv kernel: y[r] = row[r]·x for four rows r = 0..3.
 * ncol must be a multiple of 4.
 */
static
void
gemv_kernel4xN_4(int ncol, double **row, double *x, double *y)
{
	int c, r;
	double acc[4] = {0., 0., 0., 0.};

	for (c = 0; c < ncol; c += 4) {
		for (r = 0; r < 4; r++) {
			acc[r] += row[r][c+0]*x[c+0] + row[r][c+1]*x[c+1] + row[r][c+2]*x[c+2] + row[r][c+3]*x[c+3];
		}
	}

	for (r = 0; r < 4; r++) {
		y[r] = acc[r];
	}
}
+
/*
 * Portable gemv kernel for a single row: y[0] = row·x.
 * Unrolled by 4; ncol must be a multiple of 4.
 */
static
void
gemv_kernel1xN_4(int ncol, double *row, double *x, double *y)
{
	int c;
	double acc;

	acc = 0.;
	c = 0;
	while (c < ncol) {
		acc += row[c]*x[c] + row[c+1]*x[c+1] + row[c+2]*x[c+2] + row[c+3]*x[c+3];
		c += 4;
	}

	y[0] = acc;
}
+
+error
+blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y)
+{
+ int c, r, nr, nc;
+ double *row[4], res[4];
+
+ nr = nrow & ~3;
+ nc = ncol & ~3;
+ for (r = 0; r < nrow; r += 4) {
+ row[0] = m + (r * (ncol+0));
+ row[1] = m + (r * (ncol+1));
+ row[2] = m + (r * (ncol+2));
+ row[3] = m + (r * (ncol+3));
+
+ gemv_kernel4xN_4_avx2(ncol, row, x + r, res);
+
+ for (c = nc; c < ncol; c++) {
+ res[0] += row[0][c];
+ res[1] += row[1][c];
+ res[2] += row[2][c];
+ res[3] += row[3][c];
+ }
+
+ y[r+0] = res[0] + b*y[r+0];
+ y[r+1] = res[1] + b*y[r+1];
+ y[r+2] = res[2] + b*y[r+2];
+ y[r+3] = res[3] + b*y[r+3];
+ }
+
+ for (; r < nrow; r++) {
+ gemv_kernel1xN_4(nrow, m + (r * ncol), x + r, res);
+ y[r] = res[0] + b*y[r];
+ }
+
+ return 0;
+}
+
+/*
+ * rank one addition
+ * M = axy + M
+ */
+
+#define NITER 50
+#define NCOL 1000
+#define NROW 1000
+
+error
+main()
+{
+ int i;
+ clock_t t;
+ double res;
+
+ double *x, *y, *m;
+
+ openblas_set_num_threads(1);
+
+ x = malloc(sizeof(*x)*NCOL);
+ y = malloc(sizeof(*x)*NCOL);
+ m = malloc(sizeof(*x)*NCOL*NROW);
+
+ for (i = 0; i < NCOL; i++) {
+ y[i] = i;
+ }
+
+ t = clock();
+ for (i = 0; i < NITER; i++) {
+ cblas_dgemv(CblasRowMajor, CblasNoTrans, NROW, NCOL, 1.5, m, NCOL, x, 1, 2.5, y, 1);
+ }
+ t = clock() - t;
+ res = cblas_ddot(NROW, y, 1, y, 1);
+
+ printf("the result is %f\n", res);
+ printf("time elapsed (blas): %fms\n", 1000.*t/CLOCKS_PER_SEC);
+
+ for (i = 0; i < NCOL; i++) {
+ y[i] = i;
+ }
+
+ t = clock();
+ for (i = 0; i < NITER; i++) {
+ blas·gemv(NROW, NCOL, 1.5, m, x, 2.5, y);
+ }
+ t = clock() - t;
+ res = blas·dot(NCOL, y, y);
+
+ printf("the dot product is %f\n", res);
+ printf("time elapsed (naive): %fms\n", 1000.*t/CLOCKS_PER_SEC);
+
+ return 0;
+}
diff --git a/sys/libmath/rules.mk b/sys/libmath/rules.mk
index 3bf8132..9f02522 100644
--- a/sys/libmath/rules.mk
+++ b/sys/libmath/rules.mk
@@ -19,7 +19,7 @@ LIBS_$(d) :=
LIBS_$(d) := $(patsubst $(SRC_DIR)/%, $(OBJ_DIR)/%, $(LIBS_$(d)))
LIBS := $(LIBS) $(LIBS_$(d))
-BINS_$(d) := $(d)/test
+BINS_$(d) := $(d)/blas
BINS_$(d) := $(patsubst $(SRC_DIR)/%, $(OBJ_DIR)/%, $(BINS_$(d)))
BINS := $(BINS) $(BINS_$(d))
diff --git a/sys/libmath/test.c b/sys/libmath/test.c
index 4978123..3dfaa31 100644
--- a/sys/libmath/test.c
+++ b/sys/libmath/test.c
@@ -302,9 +302,9 @@ math·freemtx(math·Vec *m)
return 0;
}
-/*
+/************************************************
* multiply matrix to vector
- */
+ ***********************************************/
/*
* Notation: (number of rows) x (number of columns) _ unroll factor
@@ -312,15 +312,37 @@ math·freemtx(math·Vec *m)
*/
static
void
-mtxvec_kernel4xN_4(int ncol, double **a, double *x, double *y)
+mtxvec_kernel4xN_4_avx2(int ncol, double **row, double *x, double *y)
{
int c;
- double *row[4], res[4];
+ __m128d hr;
+ __m256d x256, r256[4];
+
+ for (c = 0; c < 4; c++) {
+ r256[c] = _mm256_setzero_pd();
+ }
+
+ for (c = 0; c < ncol; c += 4) {
+ x256 = _mm256_loadu_pd(x+c);
+ r256[0] += x256 * _mm256_loadu_pd(row[0] + c);
+ r256[1] += x256 * _mm256_loadu_pd(row[1] + c);
+ r256[2] += x256 * _mm256_loadu_pd(row[2] + c);
+ r256[3] += x256 * _mm256_loadu_pd(row[3] + c);
+ }
- row[0] = a[0];
- row[1] = a[1];
- row[2] = a[2];
- row[3] = a[3];
+ for (c = 0; c < 4; c++) {
+ hr = _mm_add_pd(_mm256_extractf128_pd(r256[c], 0), _mm256_extractf128_pd(r256[c], 1));
+ hr = _mm_hadd_pd(hr, hr);
+ y[c] = hr[0];
+ }
+}
+
+static
+void
+mtxvec_kernel4xN_4(int ncol, double **row, double *x, double *y)
+{
+ int c;
+ double res[4];
res[0] = 0.;
res[1] = 0.;
@@ -370,7 +392,7 @@ math·mtxvec(math·Mtx m, double a, math·Vec x, double b, math·Vec y)
row[2] = m.d + (r * (m.dim[1]+2));
row[3] = m.d + (r * (m.dim[1]+3));
- mtxvec_kernel4xN_4(ncol, row, x.d + r, res);
+ mtxvec_kernel4xN_4_avx2(ncol, row, x.d + r, res);
for (c = ncol; c < m.dim[1]; c++) {
res[0] += row[0][c];
@@ -393,8 +415,13 @@ math·mtxvec(math·Mtx m, double a, math·Vec x, double b, math·Vec y)
return 0;
}
+/************************************************
+ * add matrix to vector outerproduct
+ ***********************************************/
+
#define NITER 50
+#if 0
error
main()
{
@@ -441,3 +468,4 @@ main()
return 0;
}
+#endif