aboutsummaryrefslogtreecommitdiff
path: root/sys/libmath/blas.c
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-13 08:29:16 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-13 08:29:16 -0700
commitc9d4b2d7dd1d9a46571e5d2b2cf6ce10a9d9ebea (patch)
treead3c1cf1d3295760e7c32d6cdd17846febf1dbea /sys/libmath/blas.c
parentd3241acc69327081c2f9c2b1d9ed4ae96d8f1287 (diff)
unrolling blas level 1 fully works
Diffstat (limited to 'sys/libmath/blas.c')
-rw-r--r--sys/libmath/blas.c1795
1 files changed, 83 insertions, 1712 deletions
diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c
index 63f8856..f6eb830 100644
--- a/sys/libmath/blas.c
+++ b/sys/libmath/blas.c
@@ -1,1761 +1,132 @@
#include <u.h>
#include <libn.h>
-#include <libmath.h>
-#include <vendor/blas/cblas.h>
-
-#include <x86intrin.h>
#include <time.h>
-
-#define EVEN_BY(x, n) (x) & ~((n)-1)
-
-// -----------------------------------------------------------------------
-// misc functions
-
-static
-inline
-double
-hsum_avx2(__m256d x)
-{
- __m128d lo128, hi128, hi64;
-
- lo128 = _mm256_castpd256_pd128(x);
- hi128 = _mm256_extractf128_pd(x, 1);
- lo128 = _mm_add_pd(lo128, hi128);
-
- hi64 = _mm_unpackhi_pd(lo128, lo128);
-
- return _mm_cvtsd_f64(_mm_add_sd(lo128, hi64));
-}
-
-
-// -----------------------------------------------------------------------
-// level one
-
-/*
- * rotate vector
- * x = cos*x + sin*y
- * y = cos*x - sin*y
- */
-
-static
-void
-rot_kernel8_avx2(int n, double *x, double *y, double cos, double sin)
-{
- register int i;
- __m256d x256, y256;
- __m128d cos128, sin128;
- __m256d cos256, sin256;
-
- cos128 = _mm_load_sd(&cos);
- cos256 = _mm256_broadcastsd_pd(cos128);
-
- sin128 = _mm_load_sd(&sin);
- sin256 = _mm256_broadcastsd_pd(sin128);
-
- for (i = 0; i < n; i+=8) {
- x256 = _mm256_loadu_pd(x+i+0);
- y256 = _mm256_loadu_pd(y+i+0);
- _mm256_storeu_pd(x+i+0, cos256 * x256 + sin256 * y256);
- _mm256_storeu_pd(y+i+0, cos256 * y256 - sin256 * x256);
-
- x256 = _mm256_loadu_pd(x+i+4);
- y256 = _mm256_loadu_pd(y+i+4);
- _mm256_storeu_pd(x+i+4, cos256 * x256 + sin256 * y256);
- _mm256_storeu_pd(y+i+4, cos256 * y256 - sin256 * x256);
- }
-}
-
-static
-void
-rot_kernel8(int n, double *x, double *y, double cos, double sin)
-{
- register int i;
- register double tmp;
-
- for (i = 0; i < n; i+=8) {
- tmp = x[i+0], x[i+0] = cos*x[i+0] + sin*y[i+0], y[i+0] = cos*y[i+0] - sin*tmp;
- tmp = x[i+1], x[i+1] = cos*x[i+1] + sin*y[i+1], y[i+1] = cos*y[i+1] - sin*tmp;
- tmp = x[i+2], x[i+2] = cos*x[i+2] + sin*y[i+2], y[i+2] = cos*y[i+2] - sin*tmp;
- tmp = x[i+3], x[i+3] = cos*x[i+3] + sin*y[i+3], y[i+3] = cos*y[i+3] - sin*tmp;
- tmp = x[i+4], x[i+4] = cos*x[i+4] + sin*y[i+4], y[i+4] = cos*y[i+4] - sin*tmp;
- tmp = x[i+5], x[i+5] = cos*x[i+5] + sin*y[i+5], y[i+5] = cos*y[i+5] - sin*tmp;
- tmp = x[i+6], x[i+6] = cos*x[i+6] + sin*y[i+6], y[i+6] = cos*y[i+6] - sin*tmp;
- tmp = x[i+7], x[i+7] = cos*x[i+7] + sin*y[i+7], y[i+7] = cos*y[i+7] - sin*tmp;
- }
-}
-
-void
-blas·rot(int len, double *x, int incx, double *y, int incy, double cos, double sin)
-{
- register int i, n;
- register double tmp;
-
- if (incx == 1 && incy == 1) {
- n = EVEN_BY(len, 8);
- rot_kernel8_avx2(n, x, y, cos, sin);
- x += n;
- y += n;
- } else {
- n = EVEN_BY(len, 4);
- for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) {
- tmp = x[0*incx], x[0*incx] = cos*x[0*incx] + sin*y[0*incy], y[0*incy] = cos*y[0*incy] - sin*tmp;
- tmp = x[1*incx], x[1*incx] = cos*x[1*incx] + sin*y[1*incy], y[1*incy] = cos*y[1*incy] - sin*tmp;
- tmp = x[2*incx], x[2*incx] = cos*x[2*incx] + sin*y[2*incy], y[2*incy] = cos*y[2*incy] - sin*tmp;
- tmp = x[3*incx], x[3*incx] = cos*x[3*incx] + sin*y[3*incy], y[3*incy] = cos*y[3*incy] - sin*tmp;
- }
- }
-
- for (; n < len; n++, x += incx, y += incy) {
- tmp = x[0], x[0] = cos*x[0] + sin*y[0], y[0] = cos*y[0] - sin*tmp;
- }
-
-}
-
-/*
- * compute givens rotate vector
- * -- --
- * |+cos -sin| | a | = | r |
- * |-sin +cos| | b | = | 0 |
- * -- --
- */
-
-void
-blas·rotg(double *a, double *b, double *cos, double *sin)
-{
- double abs_a, abs_b, r, rho, scale, z;
-
- abs_a = math·abs(*a);
- abs_b = math·abs(*b);
- rho = abs_a > abs_b ? *a : *b;
- scale = abs_a + abs_b;
-
- if (scale == 0) {
- *cos = 1, *sin = 0;
- r = 0.;
- z = 0.;
- } else {
- r = math·sgn(rho) * scale * math·sqrt(math·pow(abs_a/scale, 2) + math·pow(abs_b/scale, 2));
- *cos = *a / r;
- *sin = *b / r;
- if (abs_a > abs_b)
- z = *sin;
- else if (abs_b >= abs_a && *cos != 0)
- z = 1/(*cos);
- else
- z = 1.;
- }
- *a = r;
- *b = z;
-}
-
-/*
- * modified Givens rotation of points in plane
- * operates on len points
- *
- * params = [flag, h11, h12, h21, h22]
- * NOTE: This is row major as opposed to other implementations
- *
- * Flags correspond to:
- * @flag = -1:
- * H -> [ [h11, h12], [h21, h22] ]
- * @flag = 0.0:
- * H -> [ [1, h12], [h21, 1] ]
- * @flag = +1:
- * H -> [ [h11, 1], [-1, h22] ]
- * @flag = -2:
- * H -> [ [1, 0], [0, 1] ]
- * @flag = *
- * return error
- *
- * Replaces:
- * x -> H11 * x + H12 * y
- * y -> H21 * x + H22 * y
- */
-
-static
-void
-rotm_kernel8_avx2(int n, double *x, double *y, double H[4])
-{
- register int i;
- __m256d x256, y256;
- __m256d H256[4];
-
- for (i = 0; i < 4; i++) {
- H256[i] = _mm256_broadcastsd_pd(_mm_load_sd(H+i));
- }
-
- for (i = 0; i < n; i+=8) {
- x256 = _mm256_loadu_pd(x+i+0);
- y256 = _mm256_loadu_pd(y+i+0);
- _mm256_storeu_pd(x+i+0, H256[0] * x256 + H256[1] * y256);
- _mm256_storeu_pd(y+i+0, H256[2] * x256 + H256[3] * y256);
-
- x256 = _mm256_loadu_pd(x+i+4);
- y256 = _mm256_loadu_pd(y+i+4);
- _mm256_storeu_pd(x+i+4, H256[0] * x256 + H256[1] * y256);
- _mm256_storeu_pd(y+i+4, H256[2] * x256 + H256[3] * y256);
- }
-}
-
-static
-void
-rotm_kernel8(int n, double *x, double *y, double H[4])
-{
- register int i;
- register double tmp;
-
- for (i = 0; i < n; i+=8) {
- tmp = x[i+0], x[i+0] = H[0]*x[i+0] + H[1]*y[i+0], y[i+0] = H[2]*tmp + H[3]*y[i+0];
- tmp = x[i+1], x[i+1] = H[0]*x[i+1] + H[1]*y[i+1], y[i+1] = H[2]*tmp + H[3]*y[i+1];
- tmp = x[i+2], x[i+2] = H[0]*x[i+2] + H[1]*y[i+2], y[i+2] = H[2]*tmp + H[3]*y[i+2];
- tmp = x[i+3], x[i+3] = H[0]*x[i+3] + H[1]*y[i+3], y[i+3] = H[2]*tmp + H[3]*y[i+3];
- tmp = x[i+4], x[i+4] = H[0]*x[i+4] + H[1]*y[i+4], y[i+4] = H[2]*tmp + H[3]*y[i+4];
- tmp = x[i+5], x[i+5] = H[0]*x[i+5] + H[1]*y[i+5], y[i+5] = H[2]*tmp + H[3]*y[i+5];
- tmp = x[i+6], x[i+6] = H[0]*x[i+6] + H[1]*y[i+6], y[i+6] = H[2]*tmp + H[3]*y[i+6];
- tmp = x[i+7], x[i+7] = H[0]*x[i+7] + H[1]*y[i+7], y[i+7] = H[2]*tmp + H[3]*y[i+7];
- }
-}
-
-error
-blas·rotm(int len, double *x, int incx, double *y, int incy, double p[5])
-{
- int i, n, flag;
- double tmp, H[4];
-
- flag = math·round(p[0]);
- switch (flag) {
- case -1: H[0] = p[1], H[1] = p[2], H[2] = p[3], H[3] = p[4]; break;
- case 0: H[0] = +1, H[1] = p[2], H[2] = p[3], H[3] = +1; break;
- case +1: H[0] = p[1], H[1] = +1, H[2] = -1, H[3] = p[4]; break;
- case -2: H[0] = +1, H[1] = 0, H[2] = 0, H[3] = +1; break;
- default:
- errorf("rotm: flag '%d' unrecognized", flag);
- return 1;
- }
-
- if (incx == 1 && incy == 1) {
- n = EVEN_BY(len, 8);
- rotm_kernel8_avx2(n, x, y, H);
- x += n;
- y += n;
- } else {
- n = EVEN_BY(len, 4);
- for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) {
- tmp = x[0*incx], x[0*incx] = H[0]*x[0*incx] + H[1]*y[0*incy], y[0*incy] = H[2]*tmp + H[3]*y[0*incy];
- tmp = x[1*incx], x[1*incx] = H[0]*x[1*incx] + H[1]*y[1*incy], y[1*incy] = H[2]*tmp + H[3]*y[1*incy];
- tmp = x[2*incx], x[2*incx] = H[0]*x[2*incx] + H[1]*y[2*incy], y[2*incy] = H[2]*tmp + H[3]*y[2*incy];
- tmp = x[3*incx], x[3*incx] = H[0]*x[3*incx] + H[1]*y[3*incy], y[3*incy] = H[2]*tmp + H[3]*y[3*incy];
- }
- }
-
- for (; n < len; n++, x += incx, y += incy) {
- tmp = x[0], x[0] = H[0]*x[0] + H[1]*y[0], y[0] = H[2]*tmp + H[3]*y[0];
- }
-
- return 0;
-}
-
-
-/*
- * scale vector
- * x = ax
- */
-
-static
-void
-scale_kernel8_avx2(int n, double *x, double a)
-{
- __m128d a128;
- __m256d a256;
- register int i;
-
- a128 = _mm_load_sd(&a);
- a256 = _mm256_broadcastsd_pd(a128);
- for (i = 0; i < n; i += 8) {
- _mm256_storeu_pd(x+i+0, a256 * _mm256_loadu_pd(x+i+0));
- _mm256_storeu_pd(x+i+4, a256 * _mm256_loadu_pd(x+i+4));
- }
-}
-
-static
-void
-scale_kernel8(int n, double *x, double a)
-{
- register int i;
- for (i = 0; i < n; i += 8) {
- x[i+0] *= a;
- x[i+1] *= a;
- x[i+2] *= a;
- x[i+3] *= a;
- x[i+4] *= a;
- x[i+5] *= a;
- x[i+6] *= a;
- x[i+7] *= a;
- }
-}
-
-void
-blas·scale(int len, double a, double *x, int inc)
-{
- int n, ix;
-
- if (inc == 1) {
- n = EVEN_BY(len, 8);
- scale_kernel8_avx2(n, x, a);
- ix = n;
- } else {
- n = EVEN_BY(len, 4);
- for (ix = 0; ix < n*inc; ix += 4*inc) {
- x[ix+0*inc] *= a;
- x[ix+1*inc] *= a;
- x[ix+2*inc] *= a;
- x[ix+3*inc] *= a;
- }
- }
- for (; n < len; n++, ix += inc) {
- x[ix] *= a;
- }
-}
-
-/*
- * copy
- * y = x
- */
-
-void
-blas·copy(int len, double *x, int incx, double *y, int incy)
-{
- int n, i, ix, iy;
- if (incx == 1 && incy == 1) {
- memcpy(y, x, sizeof(*x) * len);
- return;
- }
-
- n = EVEN_BY(len, 4);
- for (i = 0, incx = 0, incy = 0; i < n; i+=4, ix+=4*incx, iy+=4*incy) {
- y[iy+0*incy] = x[ix+0*incx];
- y[iy+1*incy] = x[ix+1*incx];
- y[iy+2*incy] = x[ix+2*incx];
- y[iy+3*incy] = x[ix+3*incx];
- }
-
- for (; n < len; n++, ix+=incx, iy+=incy) {
- y[iy] = x[ix];
- }
-}
-
-/*
- * swap
- * y <=> x
- */
-
-static
-void
-swap_kernel8_avx2(int n, double *x, double *y)
-{
- register int i;
- __m256d tmp[2];
- for (i = 0; i < n; i += 8) {
- tmp[0] = _mm256_loadu_pd(x+i+0);
- tmp[1] = _mm256_loadu_pd(y+i+0);
- _mm256_storeu_pd(x+i+0, tmp[1]);
- _mm256_storeu_pd(y+i+0, tmp[0]);
-
- tmp[0] = _mm256_loadu_pd(x+i+4);
- tmp[1] = _mm256_loadu_pd(y+i+4);
- _mm256_storeu_pd(x+i+4, tmp[1]);
- _mm256_storeu_pd(y+i+4, tmp[0]);
- }
-}
-
-static
-void
-swap_kernel8(int n, double *x, double *y)
-{
- register int i;
- register double tmp;
- for (i = 0; i < n; i += 8) {
- tmp = x[i+0], x[i+0] = y[i+0], y[i+0] = tmp;
- tmp = x[i+1], x[i+1] = y[i+1], y[i+1] = tmp;
- tmp = x[i+2], x[i+2] = y[i+2], y[i+2] = tmp;
- tmp = x[i+3], x[i+3] = y[i+3], y[i+3] = tmp;
- tmp = x[i+4], x[i+4] = y[i+4], y[i+4] = tmp;
- tmp = x[i+5], x[i+5] = y[i+5], y[i+5] = tmp;
- tmp = x[i+6], x[i+6] = y[i+6], y[i+6] = tmp;
- tmp = x[i+7], x[i+7] = y[i+7], y[i+7] = tmp;
- }
-}
-
-void
-blas·swap(int len, double *x, int incx, double *y, int incy)
-{
- int n, i, ix, iy;
- double tmp;
-
- if (incx == 1 && incy == 1) {
- n = EVEN_BY(len, 8);
- swap_kernel8(n, x, y);
- ix = n;
- iy = n;
- } else {
- n = EVEN_BY(len, 4);
- for (i = 0, ix = 0, iy = 0; i < n; i += 4, ix += 4*incx, iy += 4*incy) {
- tmp = x[ix + 0*incx], x[ix + 0*incx] = y[iy + 0*incy], y[iy + 0*incy] = tmp;
- tmp = x[ix + 1*incx], x[ix + 1*incx] = y[iy + 1*incy], y[iy + 1*incy] = tmp;
- tmp = x[ix + 2*incx], x[ix + 2*incx] = y[iy + 2*incy], y[iy + 2*incy] = tmp;
- tmp = x[ix + 3*incx], x[ix + 3*incx] = y[iy + 3*incy], y[iy + 3*incy] = tmp;
- }
- }
-
- for (; n < len; n++, ix += incx, iy += incy) {
- tmp = x[ix], x[ix] = y[iy], y[iy] = tmp;
- }
-}
-
-/*
- * daxpy
- * y = ax + y
- */
-
-
-static
-void
-axpy_kernel8_avx2(int n, double a, double *x, double *y)
-{
- __m128d a128;
- __m256d a256;
- register int i;
-
- a128 = _mm_load_sd(&a);
- a256 = _mm256_broadcastsd_pd(a128);
- for (i = 0; i < n; i += 8) {
- _mm256_storeu_pd(y+i+0, _mm256_loadu_pd(y+i+0) + a256 * _mm256_loadu_pd(x+i+0));
- _mm256_storeu_pd(y+i+4, _mm256_loadu_pd(y+i+4) + a256 * _mm256_loadu_pd(x+i+4));
- }
-}
-
-static
-void
-axpy_kernel8(int n, double a, double *x, double *y)
-{
- register int i;
- for (i = 0; i < n; i += 8) {
- y[i+0] += a*x[i+0];
- y[i+1] += a*x[i+1];
- y[i+2] += a*x[i+2];
- y[i+3] += a*x[i+3];
- y[i+4] += a*x[i+4];
- y[i+5] += a*x[i+5];
- y[i+6] += a*x[i+6];
- y[i+7] += a*x[i+7];
- }
-}
-
-void
-blas·axpy(int len, double a, double *x, int incx, double *y, int incy)
-{
- int n, i;
-
- if (incx == 1 && incy == 1) {
- n = EVEN_BY(len, 8);
- axpy_kernel8_avx2(n, a, x, y);
- x += n;
- y += n;
- } else {
- n = EVEN_BY(len, 4);
- for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) {
- y[0*incy] += a*x[0*incx];
- y[1*incy] += a*x[1*incx];
- y[2*incy] += a*x[2*incx];
- y[3*incy] += a*x[3*incx];
- }
- }
-
- for (; n < len; n++, x+=incx, y+=incy) {
- *y += a*(*x);
- }
-}
-
-/*
- * dot product
- * x·y
- */
-
-static
-double
-dot_kernel8_fma3(int len, double *x, double *y)
-{
- register int i;
- __m256d sum[4];
- __m128d res;
-
- for (i = 0; i < arrlen(sum); i++) {
- sum[i] = _mm256_setzero_pd();
- }
-
- for (i = 0; i < len; i += 16) {
- sum[0] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+0), _mm256_loadu_pd(y+i+0), sum[0]);
- sum[1] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+4), _mm256_loadu_pd(y+i+4), sum[1]);
- sum[2] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+8), _mm256_loadu_pd(y+i+8), sum[2]);
- sum[3] = _mm256_fmadd_pd(_mm256_loadu_pd(x+i+12), _mm256_loadu_pd(y+i+12), sum[3]);
- }
-
- sum[0] += sum[1] + sum[2] + sum[3];
-
- return hsum_avx2(sum[0]);
-}
-
-static
-double
-dot_kernel8_avx2(int len, double *x, double *y)
-{
- register int i;
- __m256d sum[4];
- __m128d res;
-
- for (i = 0; i < arrlen(sum); i++) {
- sum[i] = _mm256_setzero_pd();
- }
-
- for (i = 0; i < len; i += 16) {
- sum[0] += _mm256_loadu_pd(x+i+0) * _mm256_loadu_pd(y+i+0);
- sum[1] += _mm256_loadu_pd(x+i+4) * _mm256_loadu_pd(y+i+4);
- sum[2] += _mm256_loadu_pd(x+i+8) * _mm256_loadu_pd(y+i+8);
- sum[3] += _mm256_loadu_pd(x+i+12) * _mm256_loadu_pd(y+i+12);
- }
-
- sum[0] += sum[1] + sum[2] + sum[3];
-
- return hsum_avx2(sum[0]);
-}
-
-static
-double
-dot_kernel8(int len, double *x, double *y)
-{
- double res;
- register int i;
-
- for (i = 0; i < len; i += 8) {
- res += x[i] * y[i] +
- x[i+1] * y[i+1] +
- x[i+2] * y[i+2] +
- x[i+3] * y[i+3] +
- x[i+4] * y[i+4] +
- x[i+5] * y[i+5] +
- x[i+6] * y[i+6] +
- x[i+7] * y[i+7];
- }
-
- return res;
-}
-
-double
-blas·dot(int len, double *x, int incx, double *y, int incy)
-{
- int i, n;
- double mul[4], sum[2];
- if (len == 0) return 0;
-
- sum[0] = 0, sum[1] = 0;
- if (incx == 1 && incy == 1) {
- n = EVEN_BY(len, 16);
- sum[0] = dot_kernel8_fma3(n, x, y);
-
- x += n;
- y += n;
- } else {
- n = EVEN_BY(len, 4);
- for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) {
- mul[0] = x[0*incx] * y[0*incy];
- mul[1] = x[1*incx] * y[1*incy];
- mul[2] = x[2*incx] * y[2*incy];
- mul[3] = x[3*incx] * y[3*incy];
-
- sum[0] += mul[0] + mul[2];
- sum[1] += mul[1] + mul[3];
- }
- }
-
- for (; n < len; n++, x += incx, y += incy) {
- sum[0] += x[0] * y[0];
- }
-
- sum[0] += sum[1];
- return sum[0];
-}
-
-/*
- * euclidean norm
- * ||x||
- */
-double
-blas·norm(int len, double *x, int incx)
-{
- double res;
-
- res = blas·dot(len, x, incx, x, incx);
- res = math·sqrt(res);
-
- return res;
-}
-
-static
-double
-sum_kernel8_avx2(int len, double *x)
-{
- register int i;
- __m256d sum[2];
- __m128d res;
-
- for (i = 0; i < arrlen(sum); i++) {
- sum[i] = _mm256_setzero_pd();
- }
-
- for (i = 0; i < len; i += 8) {
- sum[0] += _mm256_loadu_pd(x+i+0);
- sum[1] += _mm256_loadu_pd(x+i+4);
- }
-
- sum[0] += sum[1];
-
- return hsum_avx2(sum[0]);
-}
-
-static
-double
-sum_kernel8(int len, double *x, double *y)
-{
- double res;
- register int i;
-
- for (i = 0; i < len; i += 8) {
- res += x[i] +
- x[i+1] +
- x[i+2] +
- x[i+3] +
- x[i+4] +
- x[i+5] +
- x[i+6] +
- x[i+7];
- }
-
- return res;
-}
-
-
-/*
- * L1 norm
- * sum(x_i)
- */
-double
-blas·sum(int len, double *x, int inc)
-{
- int i, n;
- double res;
-
- if (len == 0) return 0;
-
- if (inc == 1) {
- n = EVEN_BY(len, 8);
- res = sum_kernel8_avx2(n, x);
- } else {
- n = EVEN_BY(len, 4);
- for (i = 0; i < n; i++, x += 4*inc) {
- res += x[0*inc];
- res += x[1*inc];
- res += x[2*inc];
- res += x[3*inc];
- }
- }
-
- for (i = n; i < len; i++, x += inc) {
- res += x[0];
- }
-
- return res;
-}
-
-/*
- * argmax
- * i = argmax(x)
- */
-
-int
-argmax_kernel8_avx2(int n, double *x)
-{
- register int i, msk, maxidx[4];
- __m256d val, cmp, max;
-
- maxidx[0] = 0, maxidx[1] = 0, maxidx[2] = 0, maxidx[3] = 0;
- max = _mm256_loadu_pd(x);
-
- for (i = 0; i < n; i += 4) {
- val = _mm256_loadu_pd(x+i);
- cmp = _mm256_cmp_pd(val, max, _CMP_GT_OQ);
- msk = _mm256_movemask_pd(cmp);
- switch (msk) {
- case 1:
- max = _mm256_blend_pd(max, val, 1);
- maxidx[0] = i;
- break;
-
- case 2:
- max = _mm256_blend_pd(max, val, 2);
- maxidx[1] = i;
- break;
-
- case 3:
- max = _mm256_blend_pd(max, val, 3);
- maxidx[0] = i;
- maxidx[1] = i;
- break;
-
- case 4:
- max = _mm256_blend_pd(max, val, 4);
- maxidx[2] = i;
- break;
-
- case 5:
- max = _mm256_blend_pd(max, val, 5);
- maxidx[2] = i;
- maxidx[0] = i;
- break;
-
- case 6:
- max = _mm256_blend_pd(max, val, 6);
- maxidx[2] = i;
- maxidx[1] = i;
- break;
-
- case 7:
- max = _mm256_blend_pd(max, val, 7);
- maxidx[2] = i;
- maxidx[1] = i;
- maxidx[0] = i;
- break;
-
- case 8:
- max = _mm256_blend_pd(max, val, 8);
- maxidx[3] = i;
- break;
-
- case 9:
- max = _mm256_blend_pd(max, val, 9);
- maxidx[3] = i;
- maxidx[0] = i;
- break;
-
- case 10:
- max = _mm256_blend_pd(max, val, 10);
- maxidx[3] = i;
- maxidx[1] = i;
- break;
-
- case 11:
- max = _mm256_blend_pd(max, val, 11);
- maxidx[3] = i;
- maxidx[1] = i;
- maxidx[0] = i;
- break;
-
- case 12:
- max = _mm256_blend_pd(max, val, 12);
- maxidx[3] = i;
- maxidx[2] = i;
- break;
-
- case 13:
- max = _mm256_blend_pd(max, val, 13);
- maxidx[3] = i;
- maxidx[2] = i;
- maxidx[0] = i;
- break;
-
- case 14:
- max = _mm256_blend_pd(max, val, 14);
- maxidx[3] = i;
- maxidx[2] = i;
- maxidx[1] = i;
- break;
-
- case 15:
- max = _mm256_blend_pd(max, val, 15);
- maxidx[3] = i;
- maxidx[2] = i;
- maxidx[1] = i;
- maxidx[0] = i;
- break;
-
- case 0:
- default: ;
- }
- }
- maxidx[0] = (x[maxidx[0]+0] > x[maxidx[1]+1]) ? maxidx[0]+0 : maxidx[1]+1;
- maxidx[1] = (x[maxidx[2]+2] > x[maxidx[3]+3]) ? maxidx[2]+2 : maxidx[3]+3;
- maxidx[0] = (x[maxidx[0]] > x[maxidx[1]]) ? maxidx[0] : maxidx[1];
-
- return maxidx[0];
-}
+#include <vendor/blas/cblas.h>
int
-argmax_kernel8(int n, double *x)
+argmax2(int len, double *x)
{
-#define SET(d) idx[d] = d, max[d] = x[d]
-#define PUT(d) if (x[i+d] > max[d]) idx[d] = i+d, max[d] = x[i+d]
- int i, idx[8];
+ int i, ix[8];
+ double *end;
double max[8];
- SET(0); SET(1); SET(2); SET(3);
- SET(4); SET(5); SET(6); SET(7);
-
- for (i = 0; i < n; i += 8) {
- PUT(0); PUT(1); PUT(2); PUT(3);
- PUT(4); PUT(5); PUT(6); PUT(7);
- }
-
- n = 0;
- for (i = 1; i < 8; i++ ) {
- if (max[i] > max[n]) {
- n = i;
- }
- }
- return idx[n];
-#undef PUT
-#undef SET
-}
-
-int
-blas·argmax(int len, double *x, int inc)
-{
- int i, ix, n;
- double max;
-
- if (len == 0) {
- return -1;
- }
-
- if (inc == 1) {
- n = EVEN_BY(len, 8);
- ix = argmax_kernel8_avx2(n, x);
- max = x[ix];
- x += n;
- } else {
- n = EVEN_BY(len, 4);
- ix = 0;
- max = x[ix];
- for (i = 0; i < n; i += 4, x += 4*inc) {
- if (x[0*inc] > max) {
- ix = i;
- max = x[0*inc];
- }
- if (x[1*inc] > max) {
- ix = i+1;
- max = x[1*inc];
- }
- if (x[2*inc] > max) {
- ix = i+2;
- max = x[2*inc];
- }
- if (x[3*inc] > max) {
- ix = i+3;
- max = x[3*inc];
- }
- }
- }
-
- for (; n < len; n++, x += inc) {
- if (*x > max) {
- ix = n;
- max = *x;
- }
- }
-
- return ix;
-}
-
-int
-argmin_kernel8_avx2(int n, double *x)
-{
- register int i, msk, minidx[4];
- __m256d val, cmp, min;
-
- minidx[0] = 0, minidx[1] = 0, minidx[2] = 0, minidx[3] = 0;
- min = _mm256_loadu_pd(x);
-
- for (i = 0; i < n; i += 4) {
- val = _mm256_loadu_pd(x+i);
- cmp = _mm256_cmp_pd(val, min, _CMP_LT_OS);
- msk = _mm256_movemask_pd(cmp);
- switch (msk) {
- case 1:
- min = _mm256_blend_pd(min, val, 1);
- minidx[0] = i;
- break;
-
- case 2:
- min = _mm256_blend_pd(min, val, 2);
- minidx[1] = i;
- break;
-
- case 3:
- min = _mm256_blend_pd(min, val, 3);
- minidx[0] = i;
- minidx[1] = i;
- break;
-
- case 4:
- min = _mm256_blend_pd(min, val, 4);
- minidx[2] = i;
- break;
-
- case 5:
- min = _mm256_blend_pd(min, val, 5);
- minidx[2] = i;
- minidx[0] = i;
- break;
-
- case 6:
- min = _mm256_blend_pd(min, val, 6);
- minidx[2] = i;
- minidx[1] = i;
- break;
-
- case 7:
- min = _mm256_blend_pd(min, val, 7);
- minidx[2] = i;
- minidx[1] = i;
- minidx[0] = i;
- break;
-
- case 8:
- min = _mm256_blend_pd(min, val, 8);
- minidx[3] = i;
- break;
-
- case 9:
- min = _mm256_blend_pd(min, val, 9);
- minidx[3] = i;
- minidx[0] = i;
- break;
-
- case 10:
- min = _mm256_blend_pd(min, val, 10);
- minidx[3] = i;
- minidx[1] = i;
- break;
-
- case 11:
- min = _mm256_blend_pd(min, val, 11);
- minidx[3] = i;
- minidx[1] = i;
- minidx[0] = i;
- break;
-
- case 12:
- min = _mm256_blend_pd(min, val, 12);
- minidx[3] = i;
- minidx[2] = i;
- break;
-
- case 13:
- min = _mm256_blend_pd(min, val, 13);
- minidx[3] = i;
- minidx[2] = i;
- minidx[0] = i;
- break;
-
- case 14:
- min = _mm256_blend_pd(min, val, 14);
- minidx[3] = i;
- minidx[2] = i;
- minidx[1] = i;
- break;
-
- case 15:
- min = _mm256_blend_pd(min, val, 15);
- minidx[3] = i;
- minidx[2] = i;
- minidx[1] = i;
- minidx[0] = i;
- break;
-
- case 0:
- default: ;
- }
- }
- minidx[0] = (x[minidx[0]+0] < x[minidx[1]+1]) ? minidx[0]+0 : minidx[1]+1;
- minidx[1] = (x[minidx[2]+2] < x[minidx[3]+3]) ? minidx[2]+2 : minidx[3]+3;
- minidx[0] = (x[minidx[0]] < x[minidx[1]]) ? minidx[0] : minidx[1];
-
- return minidx[0];
-}
-
-
-int
-argmin_kernel8(int n, double *x)
-{
-#define SET(d) idx[d] = d, min[d] = x[d]
-#define PUT(d) if (x[i+d] < min[d]) idx[d] = i+d, min[d] = x[i+d]
- int i, idx[8];
- double min[8];
-
- SET(0); SET(1); SET(2); SET(3);
- SET(4); SET(5); SET(6); SET(7);
-
- for (i = 0; i < n; i += 8) {
- PUT(0); PUT(1); PUT(2); PUT(3);
- PUT(4); PUT(5); PUT(6); PUT(7);
- }
-
- n = 0;
- for (i = 1; i < 8; i++) {
- if (min[i] < min[n]) {
- n = i;
+ max[0] = x[0]; max[1] = x[1]; max[2] = x[2]; max[3] = x[3];
+ max[4] = x[4]; max[5] = x[5]; max[6] = x[6]; max[7] = x[7];
+ for (i = 0; i < len; i+=8) {
+ if (x[i+0] > max[0]) {
+ max[0] = *x;
+ ix[0] = i;
}
- }
- return idx[n];
-#undef PUT
-#undef SET
-}
-
-int
-blas·argmin(int len, double *x, int inc)
-{
- double min;
- int i, ix, n;
-
- if (len == 0) {
- return -1;
- }
-
- if (inc == 1) {
- n = EVEN_BY(len, 8);
- ix = argmin_kernel8_avx2(n, x);
- min = x[ix];
- x += n;
- } else {
- n = EVEN_BY(len, 4);
- ix = 0;
- min = x[ix];
- for (i = 0; i < n; i += 4, x += 4*inc) {
- if (x[0*inc] < min) {
- ix = i;
- min = x[0*inc];
- }
- if (x[1*inc] < min) {
- ix = i+1;
- min = x[1*inc];
- }
- if (x[2*inc] < min) {
- ix = i+2;
- min = x[2*inc];
- }
- if (x[3*inc] < min) {
- ix = i+3;
- min = x[3*inc];
- }
+ if (x[i+1] > max[1]) {
+ max[1] = *x;
+ ix[1] = i;
}
- }
- for (; n < len; n++, x += inc) {
- if (*x < min) {
- ix = n;
- min = *x;
+ if (x[i+2] > max[2]) {
+ max[2] = *x;
+ ix[2] = i;
}
-
- }
-
- return ix;
-}
-
-// -----------------------------------------------------------------------
-// level two
-
-/*
- * Notation: (number of rows) x (number of columns) _ unroll factor
- * N => variable we sum over
- */
-
-
-// NOTE: All triangular matrix methods are assumed packed and upper for now!
-
-/*
- * triangular shaped transformation
- * x = Mx
- * @M: square triangular
- * TODO(PERF): Diagnose speed issues
- * TODO: Finish all other flag cases!
- */
-void
-blas·tpmv(blas·Flag f, int n, double *m, double *x)
-{
- int i;
- for (i = 0; i < n; m += (n-i), ++x, ++i) {
- *x = blas·dot(n-i, m, 1, x, 1);
- }
-}
-
-/*
- * solve triangular set of equations
- * x = M^{-1}b
- * @M: square triangular
- * TODO(PERF): Diagnose speed issues
- * TODO: Finish all other flag cases!
- */
-void
-blas·tpsv(blas·Flag f, int n, double *m, double *x)
-{
- int i;
- double r;
-
- x += (n - 1);
- m += ((n * (n+1))/2 - 1);
- for (i = n-1; i >= 0; --i, --x, m -= (n-i)) {
- r = blas·dot(n-i-1, m+1, 1, x+1, 1);
- *x = (*x - r) / *m;
- }
-}
-
-/*
- * general affine transformation
- * y = aMx + by
- */
-
-static
-void
-gemv_8xN_kernel4_avx2(int ncol, double *row[8], double *x, double *y)
-{
- int c;
- __m128d hr;
- __m256d x256, r256[8];
-
- for (c = 0; c < 8; c++) {
- r256[c] = _mm256_setzero_pd();
- }
-
- for (c = 0; c < ncol; c += 4) {
- x256 = _mm256_loadu_pd(x+c);
- r256[0] += x256 * _mm256_loadu_pd(row[0] + c);
- r256[1] += x256 * _mm256_loadu_pd(row[1] + c);
- r256[2] += x256 * _mm256_loadu_pd(row[2] + c);
- r256[3] += x256 * _mm256_loadu_pd(row[3] + c);
- r256[4] += x256 * _mm256_loadu_pd(row[4] + c);
- r256[5] += x256 * _mm256_loadu_pd(row[5] + c);
- r256[6] += x256 * _mm256_loadu_pd(row[6] + c);
- r256[7] += x256 * _mm256_loadu_pd(row[7] + c);
- }
-
- y[0] = hsum_avx2(r256[0]);
- y[1] = hsum_avx2(r256[1]);
- y[2] = hsum_avx2(r256[2]);
- y[3] = hsum_avx2(r256[3]);
- y[4] = hsum_avx2(r256[4]);
- y[5] = hsum_avx2(r256[5]);
- y[6] = hsum_avx2(r256[6]);
- y[7] = hsum_avx2(r256[7]);
-}
-
-static
-void
-gemv_4xN_kernel4_avx2(int ncol, double *row[4], double *x, double *y)
-{
- int c;
- __m128d hr;
- __m256d x256, r256[4];
-
- for (c = 0; c < 4; c++) {
- r256[c] = _mm256_setzero_pd();
- }
-
- for (c = 0; c < ncol; c += 4) {
- x256 = _mm256_loadu_pd(x+c);
- r256[0] += x256 * _mm256_loadu_pd(row[0] + c);
- r256[1] += x256 * _mm256_loadu_pd(row[1] + c);
- r256[2] += x256 * _mm256_loadu_pd(row[2] + c);
- r256[3] += x256 * _mm256_loadu_pd(row[3] + c);
- }
-
- y[0] = hsum_avx2(r256[0]);
- y[1] = hsum_avx2(r256[1]);
- y[2] = hsum_avx2(r256[2]);
- y[3] = hsum_avx2(r256[3]);
-}
-
-static
-void
-gemv_4xN_kernel4(int n, double *row[4], double *x, double *y)
-{
- int c;
- double res[4];
-
- res[0] = 0.0;
- res[1] = 0.0;
- res[2] = 0.0;
- res[3] = 0.0;
-
- for (c = 0; c < n; c += 4) {
- res[0] += row[0][c+0]*x[c+0] + row[0][c+1]*x[c+1] + row[0][c+2]*x[c+2] + row[0][c+3]*x[c+3];
- res[1] += row[1][c+0]*x[c+0] + row[1][c+1]*x[c+1] + row[1][c+2]*x[c+2] + row[1][c+3]*x[c+3];
- res[2] += row[2][c+0]*x[c+0] + row[2][c+1]*x[c+1] + row[2][c+2]*x[c+2] + row[2][c+3]*x[c+3];
- res[3] += row[3][c+0]*x[c+0] + row[3][c+1]*x[c+1] + row[3][c+2]*x[c+2] + row[3][c+3]*x[c+3];
- }
-
- y[0] = res[0];
- y[1] = res[1];
- y[2] = res[2];
- y[3] = res[3];
-}
-
-static
-void
-gemv_2xN_kernel4_avx2(int n, double *row[2], double *x, double *y)
-{
- int c;
- __m128d hr;
- __m256d x256, r256[2];
-
- for (c = 0; c < 2; c++) {
- r256[c] = _mm256_setzero_pd();
- }
-
- for (c = 0; c < n; c += 4) {
- x256 = _mm256_loadu_pd(x+c);
- r256[0] += x256 * _mm256_loadu_pd(row[0] + c);
- r256[1] += x256 * _mm256_loadu_pd(row[1] + c);
- }
-
- y[0] = hsum_avx2(r256[0]);
- y[1] = hsum_avx2(r256[1]);
-}
-
-static
-void
-gemv_2xN_kernel4(int n, double *row[2], double *x, double *y)
-{
- int c;
- double res[2];
-
- res[0] = 0.0;
- res[1] = 0.0;
-
- for (c = 0; c < n; c += 4) {
- res[0] += row[0][c+0]*x[c+0] + row[0][c+1]*x[c+1] + row[0][c+2]*x[c+2] + row[0][c+3]*x[c+3];
- res[1] += row[1][c+0]*x[c+0] + row[1][c+1]*x[c+1] + row[1][c+2]*x[c+2] + row[1][c+3]*x[c+3];
- }
-
- y[0] = res[0];
- y[1] = res[1];
-}
-
-static
-void
-gemv_1xN_kernel4_avx2(int n, double *row, double *x, double *y)
-{
- int c;
- __m128d r128;
- __m256d r256;
-
- r256 = _mm256_setzero_pd();
- for (c = 0; c < n; c += 4) {
- r256 += _mm256_loadu_pd(row+c) * _mm256_loadu_pd(x+c);
- }
-
- r128 = _mm_add_pd(_mm256_extractf128_pd(r256, 0), _mm256_extractf128_pd(r256, 1));
- r128 = _mm_hadd_pd(r128, r128);
-
- *y = r128[0];
- *y = hsum_avx2(r256);
-}
-
-static
-void
-gemv_1xN_kernel4(int n, double *row, double *x, double *y)
-{
- int c;
- double res;
-
- res = 0.;
- for (c = 0; c < n; c += 4) {
- res += row[c+0]*x[c+0] + row[c+1]*x[c+1] + row[c+2]*x[c+2] + row[c+3]*x[c+3];
- }
-
- *y = res;
-}
-
-error
-blas·gemv(int nrow, int ncol, double a, double *m, int incm, double *x, int incx, double b, double *y, int incy)
-{
- int c, r, nr, nc;
- double *row[8], res[8];
- enum {
- err·nil,
- err·incm,
- };
-
- if (incm < ncol) {
- errorf("aliased matrix: inc = %d < ncols = %d", incm, ncol);
- return err·incm;
- }
-
- if (incx == 1 && incy == 1) {
- nc = EVEN_BY(ncol, 4);
-
- nr = EVEN_BY(nrow, 8);
- for (r = 0; r < nr; r += 8) {
- row[0] = m + ((r+0) * incm);
- row[1] = m + ((r+1) * incm);
- row[2] = m + ((r+2) * incm);
- row[3] = m + ((r+3) * incm);
- row[4] = m + ((r+4) * incm);
- row[5] = m + ((r+5) * incm);
- row[6] = m + ((r+6) * incm);
- row[7] = m + ((r+7) * incm);
-
- gemv_8xN_kernel4_avx2(nc, row, x, res);
-
- for (c = nc; c < ncol; c++) {
- res[0] += row[0][c]*x[c];
- res[1] += row[1][c]*x[c];
- res[2] += row[2][c]*x[c];
- res[3] += row[3][c]*x[c];
- res[4] += row[4][c]*x[c];
- res[5] += row[5][c]*x[c];
- res[6] += row[6][c]*x[c];
- res[7] += row[7][c]*x[c];
- }
-
- y[r+0] = a*res[0] + b*y[r+0];
- y[r+1] = a*res[1] + b*y[r+1];
- y[r+2] = a*res[2] + b*y[r+2];
- y[r+3] = a*res[3] + b*y[r+3];
- y[r+4] = a*res[4] + b*y[r+4];
- y[r+5] = a*res[5] + b*y[r+5];
- y[r+6] = a*res[6] + b*y[r+6];
- y[r+7] = a*res[7] + b*y[r+7];
+ if (x[i+3] > max[3]) {
+ max[3] = *x;
+ ix[3] = i;
}
-
- nr = EVEN_BY(nrow, 4);
- for (; r < nr; r += 4) {
- row[0] = m + ((r+0) * incm);
- row[1] = m + ((r+1) * incm);
- row[2] = m + ((r+2) * incm);
- row[3] = m + ((r+3) * incm);
-
- gemv_4xN_kernel4_avx2(nc, row, x, res);
-
- for (c = nc; c < ncol; c++) {
- res[0] += row[0][c]*x[c];
- res[1] += row[1][c]*x[c];
- res[2] += row[2][c]*x[c];
- res[3] += row[3][c]*x[c];
- }
-
- y[r+0] = a*res[0] + b*y[r+0];
- y[r+1] = a*res[1] + b*y[r+1];
- y[r+2] = a*res[2] + b*y[r+2];
- y[r+3] = a*res[3] + b*y[r+3];
+ if (x[i+4] > max[4]) {
+ max[4] = *x;
+ ix[4] = i;
}
-
- nr = EVEN_BY(nrow, 2);
- for (; r < nr; r += 2) {
- row[0] = m + ((r+0) * incm);
- row[1] = m + ((r+1) * incm);
- gemv_2xN_kernel4_avx2(nc, row, x, res);
-
- for (c = nc; c < ncol; c++) {
- res[0] += row[0][c]*x[c];
- res[1] += row[1][c]*x[c];
- }
-
- y[r+0] = a*res[0] + b*y[r+0];
- y[r+1] = a*res[1] + b*y[r+1];
+ if (x[i+5] > max[5]) {
+ max[5] = *x;
+ ix[5] = i;
}
-
- for (; r < nrow; r++) {
- row[0] = m + ((r+0) * ncol);
- res[0] = blas·dot(ncol, row[0], 1, x, 1);
- y[r] = a*res[0] + b*y[r];
+ if (x[i+6] > max[6]) {
+ max[6] = *x;
+ ix[6] = i;
}
- }
-
- return 0;
-}
-
-/*
- * rank one addition
- * M = ax(y^T) + M
- * TODO: vectorize kernel
- */
-
-void
-blas·ger(int nrow, int ncol, double a, double *x, double *y, double *m)
-{
- int i, j;
-
- for (i = 0; i < nrow; i++) {
- for (j = 0; j < ncol; j++) {
- m[i+ncol*j] += a * x[i] * y[j];
+ if (x[i+7] > max[7]) {
+ max[7] = *x;
+ ix[7] = i;
}
}
-}
-
-/*
- * rank one addition
- * M = ax(x^T) + M
- */
-
-void
-blas·her(int n, double a, double *x, double *m)
-{
- blas·ger(n, n, a, x, x, m);
-}
-
-/*
- * symmetric rank one addition
- * M = ax(x^T) + M
- * TODO: vectorize kernel
- */
-
-void
-blas·syr(int nrow, int ncol, double a, double *x, double *m)
-{
- int i, j;
-
- for (i = 0; i < nrow; i++) {
- for (j = 0; j < ncol; j++) {
- m[i+ncol*j] += a * x[i] * x[j];
- }
- }
-}
-
-
-// -----------------------------------------------------------------------
-// level three
-
-/*
- * matrix multiplication
- * m3 = a(m1 * m2) + b(m3)
- * einstein notation:
- * m3_{ij} = a m1_{ik} m2_{kj} + b m3_{ij}
- *
- * n1 = # rows of m1 = # rows of m3
- * n2 = # cols of m2 = # cols of m3
- * n3 = # cols of m1 = # rows of m2
- *
- * TODO: Right now we are 2x slower than OpenBLAS.
- * This is because we use a very simple algorithm.
- */
-
-void
-blas·gemm(int n1, int n2, int n3, double a, double *m1, double *m2, double b, double *m3)
-{
- int i, j, k, len;
- // TODO: Is there anyway this computation can be integrated into the one below?
- for (i = 0; i < n1; i++) {
- for (j = 0; j < n2; j++) {
- m3[i + n2*j] *= b;
+ for (i = 1; i < 8; i++) {
+ if (max[i] > max[0]) {
+ max[0] = max[i];
+ ix[0] = ix[i] + i;
}
}
- len = n1 & ~7;
- for (j = 0; j < n2; j++) {
- for (k = 0; k < n3; k++) {
- axpy_kernel8_avx2(len, a * m2[k + n2*j], m1 + n3*k, m3 + n2*j);
-
- // remainder
- for (i = len; i < n1; i++) {
- m3[i + n2*j] += a * m1[i + n3*k] * m2[k + n2*j];
- }
- }
- }
+ return ix[0];
}
-/*
- * triangular matrix multiplication
- * m2 = a * m1 * m2 _OR_ a * m2 * m1
- * m1 is assumed triangular
- * einstein notation:
- * m2_{ij} = a m1_{ik} m2_{kj} _OR_ a m1_{kj} m2_{ik}
- *
- * nrow = # rows of m2
- * ncol = # cols of m2
- * TODO(PERF): make compute kernel
- * TODO: finish all other flags
- */
-void
-blas·trmm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2)
+int
+argmax(int len, double *x)
{
- int i, j, k, len;
+ int i, ix;
+ double *end;
+ double max;
- for (i = 0; i < nrow; i++) {
- for (j = 0; j < ncol; j++) {
- for (k = i; k < ncol; k++) {
- m2[i + ncol*j] += a * m1[i + nrow*k] * m2[k + ncol*j];
- }
+ max = *x;
+ for (i = 0; i < len; i++) {
+ if (x[i] > max) {
+ max = *x;
+ ix = i;
}
}
-}
-/*
- * solve triangular matrix system of equations
- * m2 = a * m1^{-1L} _OR_ a * m2 * m1
- * m1 is assumed triangular
- *
- * nrow = # rows of m2
- * ncol = # cols of m2
- * TODO: complete stub
- * TODO: finish all other flags
- */
-void
-blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2)
-{
+ return ix;
}
-#define NITER 1000
-#define NCOL 2005
-#define NROW 2005
-
+#define LEN 1000000
+#define NIT 2000
error
-test·level3()
-{
- vlong i, n;
- clock_t t;
-
- double *x, *y, *m[3];
- double res[2], tprof[2];
-
- // openblas_set_num_threads(1);
-
- x = malloc(sizeof(*x)*NCOL);
- y = malloc(sizeof(*x)*NCOL);
- m[0] = malloc(sizeof(*x)*NCOL*NROW);
- m[1] = malloc(sizeof(*x)*NCOL*NROW);
- m[2] = malloc(sizeof(*x)*NCOL*NROW);
-
- tprof[0] = 0;
- tprof[1] = 0;
-
- for (n = 0; n < NITER; n++) {
- for (i = 0; i < NCOL; i++) {
- x[i] = i*i+1;
- y[i] = i+1;
- }
-
- for (i = 0; i < NCOL*NROW; i++) {
- m[0][i] = i/(NCOL*NROW);
- m[1][i] = i*i/(NCOL*NROW*NCOL*NROW);
- m[2][i] = 1;
- }
-
- t = clock();
- cblas_dgemm(CblasRowMajor, CblasNoTrans, CblasNoTrans, NCOL, NROW, NROW, 1.2, m[0], NROW, m[1], NROW, 2.8, m[2], NROW);
- t = clock() - t;
- tprof[1] += 1000.*t/CLOCKS_PER_SEC;
- res[1] = cblas_ddot(NROW*NCOL, m[2], 1, m[2], 1);
-
- for (i = 0; i < NCOL; i++) {
- x[i] = i*i+1;
- y[i] = i+1;
- }
-
- for (i = 0; i < NCOL*NROW; i++) {
- m[0][i] = i/(NCOL*NROW);
- m[1][i] = i*i/(NCOL*NROW*NCOL*NROW);
- m[2][i] = 1;
- }
-
- t = clock();
- blas·gemm(NROW, NROW, NROW, 1.2, m[0], m[1], 2.8, m[2]);
- t = clock() - t;
- tprof[0] += 1000.*t/CLOCKS_PER_SEC;
- res[0] = blas·dot(NROW*NCOL, m[2], 1, m[2], 1);
- }
- printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER);
- printf("--> result (naive): %f\n", res[0]);
- printf("mean time/iteration (oblas): %fms\n", tprof[1]/NITER);
- printf("--> result (oblas): %f\n", res[1]);
-
- return 0;
-}
-
-void
-test·level2()
-{
- int i, j, n, it;
- clock_t t;
-
- double *x, *y, *z, *m;
- double tprof[2];
-
- rng·init(0);
-
- tprof[0] = 0;
- tprof[1] = 0;
-
- x = malloc(sizeof(*x)*NCOL);
- y = malloc(sizeof(*x)*NCOL);
- z = malloc(sizeof(*x)*NCOL);
- m = malloc(sizeof(*x)*NROW*NCOL);
-
- for (it = 0; it < NITER; it++) {
- n = 0;
- for (i = 0; i < NROW; i++) {
- x[i] = rng·random();
- y[i] = rng·random();
- for (j = 0; j < NCOL; j++) {
- m[n++] = rng·random() + .1;
- }
- }
-
- memcpy(z, y, NCOL * sizeof(*y));
-
- t = clock();
- blas·gemv(NROW, NCOL, 2, m, NCOL, x, 1, 1.0, y, 1);
- t = clock() - t;
-
- tprof[0] += 1000.*t/CLOCKS_PER_SEC;
-
- t = clock();
- cblas_dgemv(CblasRowMajor, CblasNoTrans, NROW, NCOL, 2, m, NROW, x, 1, 1.0, z, 1);
- t = clock() - t;
-
- tprof[1] += 1000.*t/CLOCKS_PER_SEC;
-
- for (i = 0; i < NCOL; i++) {
- if (math·abs(z[i] - y[i])/math·abs(x[i]) > 1e-5) {
- errorf("failure at index %d: %f != %f", i, z[i], y[i]);
- }
- }
- }
-
- printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER);
- printf("mean time/iteration (oblas): %fms\n", tprof[1]/NITER);
-}
-
-void
-print·array(int n, double *x)
+main()
{
- double *end;
- printf("[");
- for (end=x+n; x != end; ++x) {
- printf("%f,", *x);
- }
- printf("]\n");
-}
+ int i, nit;
+ double *x, *y, res[3];
-error
-test·level1()
-{
- int ai, ai2, i, n;
- double *x, *y;
- double tprof[2];
- // double params[5];
clock_t t;
+ double tprof[3] = { 0 };
- x = malloc(sizeof(*x)*NCOL);
- y = malloc(sizeof(*x)*NCOL);
rng·init(0);
- // params[0] = -1.;
- // params[1] = 100; params[2] = 20; params[3] = 30; params[4] = 10;
-
- for (n = 0; n < NITER; n++) {
- for (i = 0; i < NCOL; i++) {
- y[i] = rng·random();
- }
- memcpy(x, y, sizeof(*x)*NCOL);
-
- t = clock();
- ai = blas·argmin(NCOL, x, 1);
- t = clock() - t;
- tprof[0] += 1000.*t/CLOCKS_PER_SEC;
+ x = malloc(sizeof(*x)*LEN);
+ // y = malloc(sizeof(*x)*LEN);
- if (n == 20729) {
- printf("[%d]=%f vs [%d]=%f\n", 74202, x[74202], 3, x[3]);
- }
- memcpy(x, y, sizeof(*x)*NCOL);
-
- t = clock();
- ai2 = cblas_idamin(NCOL, x, 1);
- t = clock() - t;
- tprof[1] += 1000.*t/CLOCKS_PER_SEC;
-
- if (ai != ai2) {
- printf("iteration %d: %d not equal to %d. %f vs %f\n", n, ai, ai2, x[ai], x[ai2]);
- }
- }
-
- printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER);
- printf("mean time/iteration (oblas): %fms\n", tprof[1]/NITER);
-
- double a, b, c, s;
-
- a = 10.234, b = 2.;
- cblas_drotg(&a, &b, &c, &s);
- printf("%f, %f, %f, %f\n", a, b, c, s);
-
- a = 10.234, b = 2.;
- blas·rotg(&a, &b, &c, &s);
- printf("%f, %f, %f, %f\n", a, b, c, s);
-
- return 0;
-}
+#define DO_0 t = clock(); \
+ res[0] += argmax(LEN, x); \
+ t = clock() - t; \
+ tprof[0] += 1000.*t/CLOCKS_PER_SEC; \
-#define STEP 1
-error
-test·argmax()
-{
- int i, n;
- double *x, *y, *w, *z;
- double res[2], tprof[2];
- int idx[2];
- clock_t t;
+#define DO_1 t = clock(); \
+ res[1] += argmax2(LEN, x); \
+ t = clock() - t; \
+ tprof[1] += 1000.*t/CLOCKS_PER_SEC; \
- x = malloc(sizeof(*x)*NCOL);
- y = malloc(sizeof(*x)*NCOL);
- w = malloc(sizeof(*x)*NCOL);
- z = malloc(sizeof(*x)*NCOL);
- rng·init(0);
+#define DO_2 t = clock(); \
+ res[2] += cblas_idamax(LEN, x, 1); \
+ t = clock() - t; \
+ tprof[2] += 1000.*t/CLOCKS_PER_SEC;
- for (n = 0; n < NITER; n++) {
- for (i = 0; i < NCOL; i++) {
+ for (nit = 0; nit < NIT; nit++) {
+ for (i = 0; i < LEN; i++) {
x[i] = rng·random();
- y[i] = rng·random();
-
- w[i] = x[i];
- z[i] = y[i];
+ // y[i] = rng·random();
}
- t = clock();
- idx[0] = cblas_idamin(NCOL/STEP, w, STEP);
- t = clock() - t;
- tprof[1] += 1000.*t/CLOCKS_PER_SEC;
-
- t = clock();
- idx[1] = blas·argmin(NCOL/STEP, x, STEP);
- t = clock() - t;
- tprof[0] += 1000.*t/CLOCKS_PER_SEC;
-
- if (idx[0] != idx[1]) {
- errorf("%d != %d", idx[0], idx[1]);
+ switch (nit % 6) {
+ case 0: DO_0; DO_1; DO_2; break;
+ case 1: DO_0; DO_2; DO_1; break;
+ case 2: DO_1; DO_0; DO_2; break;
+ case 3: DO_1; DO_2; DO_0; break;
+ case 4: DO_2; DO_0; DO_1; break;
+ case 5: DO_2; DO_1; DO_0; break;
}
- // if (math·abs(res[0] - res[1])/math·abs(res[0]) > 1e-4) {
- // errorf("%f != %f", res[0], res[1]);
- // }
- // for (i = 0; i < NCOL; i++) {
- // if (math·abs(x[i] - w[i]) + math·abs(y[i] - z[i]) > 1e-4) {
- // errorf("%f != %f & %f != %f at index %d", x[i], w[i], y[i], z[i], i);
- // }
- // }
}
+ printf("mean time/iteration (naive): %fms\n", tprof[0]/NIT);
+ printf("--> result (naive): %f\n", res[0]);
+ printf("mean time/iteration (unrolled): %fms\n", tprof[1]/NIT);
+ printf("--> result (unrolled): %f\n", res[1]);
+ printf("mean time/iteration (openblas): %fms\n", tprof[2]/NIT);
+ printf("--> result (openblas): %f\n", res[2]);
return 0;
}
-
-error
-main()
-{
- test·level2();
- return 0;
-}