From 672079795d607270638103dd93fa453645e5a38a Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Sat, 9 May 2020 13:22:51 -0700 Subject: feat: all level 1 functions are now strided --- sys/libmath/blas.c | 310 +++++++++++++++++++++++++++++++++++++---------------- 1 file changed, 218 insertions(+), 92 deletions(-) (limited to 'sys/libmath/blas.c') diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c index c419c6f..1e1e1c7 100644 --- a/sys/libmath/blas.c +++ b/sys/libmath/blas.c @@ -6,6 +6,8 @@ #include #include +#define EVEN_BY(x, n) (x) & ~((n)-1) + // ----------------------------------------------------------------------- // level one @@ -219,10 +221,10 @@ blas·rotm(int len, double *x, int incx, double *y, int incy, double p[5]) } else { n = EVEN_BY(len, 4); for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) { - tmp = x[0*incx], x[0*incx] = H[0]*x[0*incx] + H[1]*y[0*incy], y[0*incy] = H[2]*y[0*incy] + H[3]*tmp; - tmp = x[1*incx], x[1*incx] = H[0]*x[1*incx] + H[1]*y[1*incy], y[1*incy] = H[2]*y[1*incy] + H[3]*tmp; - tmp = x[2*incx], x[2*incx] = H[0]*x[2*incx] + H[1]*y[2*incy], y[2*incy] = H[2]*y[2*incy] + H[3]*tmp; - tmp = x[3*incx], x[3*incx] = H[0]*x[3*incx] + H[1]*y[3*incy], y[3*incy] = H[2]*y[3*incy] + H[3]*tmp; + tmp = x[0*incx], x[0*incx] = H[0]*x[0*incx] + H[1]*y[0*incy], y[0*incy] = H[2]*tmp + H[3]*y[0*incy]; + tmp = x[1*incx], x[1*incx] = H[0]*x[1*incx] + H[1]*y[1*incy], y[1*incy] = H[2]*tmp + H[3]*y[1*incy]; + tmp = x[2*incx], x[2*incx] = H[0]*x[2*incx] + H[1]*y[2*incy], y[2*incy] = H[2]*tmp + H[3]*y[2*incy]; + tmp = x[3*incx], x[3*incx] = H[0]*x[3*incx] + H[1]*y[3*incy], y[3*incy] = H[2]*tmp + H[3]*y[3*incy]; } } @@ -273,15 +275,25 @@ scale_kernel8(int n, double *x, double a) } void -blas·scale(int len, double a, double *x) +blas·scale(int len, double a, double *x, int inc) { - int n; - - n = len & ~7; - scale_kernel8_avx2(n, x, a); + int n, ix; - for (; n < len; n++) { - x[n] *= a; + if (inc == 1) { + n = EVEN_BY(len, 8); + scale_kernel8_avx2(n, x, a); + ix = n; + } else { + n = EVEN_BY(len, 4); + for (ix = 0; ix < n*inc; ix += 4*inc) { + x[ix+0*inc] *= a; + x[ix+1*inc] *= a; + x[ix+2*inc] *= a; + x[ix+3*inc] *= a; + } + } + for (; n < len; n++, ix += inc) { + x[ix] *= a; } } @@ -291,9 +303,25 @@ blas·scale(int len, double a, double *x) */ void -blas·copy(int len, double *x, double *y) +blas·copy(int len, double *x, int incx, double *y, int incy) { - memcpy(y, x, sizeof(*x) * len); + int n, i, ix, iy; + if (incx == 1 && incy == 1) { + memcpy(y, x, sizeof(*x) * len); + return; + } + + n = EVEN_BY(len, 4); + for (i = 0, incx = 0, incy = 0; i < n; i+=4, ix+=4*incx, iy+=4*incy) { + y[iy+0*incy] = x[ix+0*incx]; + y[iy+1*incy] = x[ix+1*incx]; + y[iy+2*incy] = x[ix+2*incx]; + y[iy+3*incy] = x[ix+3*incx]; + } + + for (; n < len; n++, ix+=incx, iy+=incy) { + y[iy] = x[ix]; + } } /* @@ -339,15 +367,28 @@ swap_kernel8(int n, double *x, double *y) } void -blas·swap(int len, double *x, double *y) +blas·swap(int len, double *x, int incx, double *y, int incy) { - int n; + int n, i, ix, iy; double tmp; - n = len & ~7; - swap_kernel8(n, x, y); - for (; n < len; n++) { - tmp = x[n], x[n] = y[n], y[n] = tmp; + if (incx == 1 && incy == 1) { + n = EVEN_BY(len, 8); + swap_kernel8(n, x, y); + ix = n; + iy = n; + } else { + n = EVEN_BY(len, 4); + for (i = 0, ix = 0, iy = 0; i < n; i += 4, ix += 4*incx, iy += 4*incy) { + tmp = x[ix + 0*incx], x[ix + 0*incx] = y[iy + 0*incy], y[iy + 0*incy] = tmp; + tmp = x[ix + 1*incx], x[ix + 1*incx] = y[iy + 1*incy], y[iy + 1*incy] = tmp; + tmp = x[ix + 2*incx], x[ix + 2*incx] = y[iy + 2*incy], y[iy + 2*incy] = tmp; + tmp = x[ix + 3*incx], x[ix + 3*incx] = y[iy + 3*incy], y[iy + 3*incy] = tmp; + } + } + + for (; n < len; n++, ix += incx, iy += incy) { + tmp = x[ix], x[ix] = y[iy], y[iy] = tmp; } } @@ -391,15 +432,27 @@ axpy_kernel8(int n, double a, double *x, double *y) } void -blas·axpy(int len, double a, double *x, double *y) +blas·axpy(int len, double a, double *x, int incx, double *y, int incy) { - int n; + int n, i; - n = len & ~7; - axpy_kernel8_avx2(n, a, x, y); + if (incx == 1 && incy == 1) { + n = EVEN_BY(len, 8); + axpy_kernel8_avx2(n, a, x, y); + x += n; + y += n; + } else { + n = EVEN_BY(len, 4); + for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) { + y[0*incy] += a*x[0*incx]; + y[1*incy] += a*x[1*incx]; + y[2*incy] += a*x[2*incx]; + y[3*incy] += a*x[3*incx]; + } + } - for (; n < len; n++) { - y[n] += a*x[n]; + for (; n < len; n++, x+=incx, y+=incy) { + *y += a*(*x); } } @@ -486,38 +539,35 @@ dot_kernel8(int len, double *x, double *y) double blas·dot(int len, double *x, int incx, double *y, int incy) { - int i, n, ix, iy; - double res, mul[4], sum[2]; - + int i, n; + double mul[4], sum[2]; if (len == 0) return 0; + sum[0] = 0, sum[1] = 0; if (incx == 1 && incy == 1) { - n = len & ~15; // neat trick - res = dot_kernel8_fma3(n, x, y); - - for (i = n; i < len; i++) { - res += x[i] * y[i]; + n = EVEN_BY(len, 16); + sum[0] = dot_kernel8_fma3(n, x, y); + x += n; + y += n; + } else { + n = EVEN_BY(len, 4); + for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) { + mul[0] = x[0*incx] * y[0*incy]; + mul[1] = x[1*incx] * y[1*incy]; + mul[2] = x[2*incx] * y[2*incy]; + mul[3] = x[3*incx] * y[3*incy]; + + sum[0] += mul[0] + mul[2]; + sum[1] += mul[1] + mul[3]; } - return res; - } - - n = len & ~3; - for (i = 0, ix = 0, iy = 0; i < n; i += 4, ix += 4*incx, iy += 4*incy) { - mul[0] = x[ix+0*incx] * y[iy+0*incy]; - mul[1] = x[ix+1*incx] * y[iy+1*incy]; - mul[2] = x[ix+2*incx] * y[iy+2*incy]; - mul[3] = x[ix+3*incx] * y[iy+3*incy]; - - sum[0] += mul[0] + mul[2]; - sum[1] += mul[1] + mul[3]; } - for (; i < len; i++, ix += incx, iy += incy) { - sum[0] += x[ix] * y[iy]; + for (; i < len; i++, x += incx, y += incy) { + sum[0] += x[0] * y[0]; } - res = sum[0] + sum[1]; - return res; + sum[0] += sum[1]; + return sum[0]; } /* @@ -525,11 +575,11 @@ blas·dot(int len, double *x, int incx, double *y, int incy) * ||x|| */ double -blas·norm(int len, double *x) +blas·norm(int len, double *x, int incx) { double res; - res = blas·dot(len, x, 1, x, 1); + res = blas·dot(len, x, incx, x, incx); res = math·sqrt(res); return res; @@ -587,18 +637,28 @@ sum_kernel8(int len, double *x, double *y) * sum(x_i) */ double -blas·sum(int len, double *x) +blas·sum(int len, double *x, int inc) { int i, n; double res; if (len == 0) return 0; - n = len & ~7; - res = sum_kernel8_avx2(n, x); + if (inc == 1) { + n = EVEN_BY(len, 8); + res = sum_kernel8_avx2(n, x); + } else { + n = EVEN_BY(len, 4); + for (i = 0; i < n; i++, x += 4*inc) { + res += x[0*inc]; + res += x[1*inc]; + res += x[2*inc]; + res += x[3*inc]; + } + } - for (i = n; i < len; i++) { - res += x[i]; + for (i = n; i < len; i++, x += inc) { + res += x[0]; } return res; @@ -754,23 +814,52 @@ argmax_kernel8(int n, double *x) } int -blas·argmax(int len, double *x) +blas·argmax(int len, double *x, int inc) { - int i, n; + int i, ix, n; double max; - i = len & ~7; - n = argmax_kernel8_avx2(i, x); + if (len == 0) { + return -1; + } - max = x[n]; - for (; i < len; i++) { - if (x[i] > max) { - n = i; - max = x[i]; + if (inc == 1) { + n = EVEN_BY(len, 8); + ix = argmax_kernel8_avx2(n, x); + max = x[ix]; + x += n; + } else { + n = EVEN_BY(len, 4); + ix = 0; + max = x[ix]; + for (i = 0; i < n; i += 4, x += 4*inc) { + if (x[0*inc] > max) { + ix = i; + max = x[0*inc]; + } + if (x[1*inc] > max) { + ix = i+1; + max = x[1*inc]; + } + if (x[2*inc] > max) { + ix = i+2; + max = x[2*inc]; + } + if (x[3*inc] > max) { + ix = i+3; + max = x[3*inc]; + } } } - return n; + for (; n < len; n++, x += inc) { + if (*x > max) { + ix = n; + max = *x; + } + } + + return ix; } int @@ -919,23 +1008,52 @@ argmin_kernel8(int n, double *x) } int -blas·argmin(int len, double *x) +blas·argmin(int len, double *x, int inc) { - int i, n; double min; + int i, ix, n; - i = len & ~7; - n = argmin_kernel8_avx2(i, x); + if (len == 0) { + return -1; + } - min = x[n]; - for (; i < len; i++) { - if (x[i] < min) { - n = i; - min = x[i]; + if (inc == 1) { + n = EVEN_BY(len, 8); + ix = argmin_kernel8_avx2(n, x); + min = x[ix]; + x += n; + } else { + n = EVEN_BY(len, 4); + ix = 0; + min = x[ix]; + for (i = 0; i < n; i += 4, x += 4*inc) { + if (x[0*inc] < min) { + ix = i; + min = x[0*inc]; + } + if (x[1*inc] < min) { + ix = i+1; + min = x[1*inc]; + } + if (x[2*inc] < min) { + ix = i+2; + min = x[2*inc]; + } + if (x[3*inc] < min) { + ix = i+3; + min = x[3*inc]; + } } } + for (; n < len; n++, x += inc) { + if (*x < min) { + ix = n; + min = *x; + } - return n; + } + + return ix; } // ----------------------------------------------------------------------- @@ -1225,7 +1343,7 @@ blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2) { } -#define NITER 10000 +#define NITER 3000 #define NCOL 5000007 #define NROW 57 @@ -1374,7 +1492,7 @@ test·level1() memcpy(x, y, sizeof(*x)*NCOL); t = clock(); - ai = blas·argmin(NCOL, x); + ai = blas·argmin(NCOL, x, 1); t = clock() - t; tprof[0] += 1000.*t/CLOCKS_PER_SEC; @@ -1409,49 +1527,57 @@ test·level1() return 0; } +#define STEP 1 error main() { int i, n; - double *x, *y; + double *x, *y, *w, *z; double res[2], tprof[2]; + int idx[2]; clock_t t; x = malloc(sizeof(*x)*NCOL); y = malloc(sizeof(*x)*NCOL); + w = malloc(sizeof(*x)*NCOL); + z = malloc(sizeof(*x)*NCOL); rng·init(0); for (n = 0; n < NITER; n++) { for (i = 0; i < NCOL; i++) { x[i] = rng·random(); y[i] = rng·random(); + + w[i] = x[i]; + z[i] = y[i]; } t = clock(); - res[1] += cblas_ddot(NCOL/4, x, 4, y, 4); + idx[0] = cblas_idamin(NCOL/STEP, w, STEP); t = clock() - t; tprof[1] += 1000.*t/CLOCKS_PER_SEC; t = clock(); - res[0] += blas·dot(NCOL/4, x, 4, y, 4); + idx[1] = blas·argmin(NCOL/STEP, x, STEP); t = clock() - t; tprof[0] += 1000.*t/CLOCKS_PER_SEC; + + if (idx[0] != idx[1]) { + errorf("%d != %d", idx[0], idx[1]); + } + // if (math·abs(res[0] - res[1])/math·abs(res[0]) > 1e-4) { + // errorf("%f != %f", res[0], res[1]); + // } + // for (i = 0; i < NCOL; i++) { + // if (math·abs(x[i] - w[i]) + math·abs(y[i] - z[i]) > 1e-4) { + // errorf("%f != %f & %f != %f at index %d", x[i], w[i], y[i], z[i], i); + // } + // } } printf("%f, %f\n", res[0], res[1]); printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER); printf("mean time/iteration (oblas): %fms\n", tprof[1]/NITER); - double a, b, c, s; - - a = 10.234, b = 2.; - cblas_drotg(&a, &b, &c, &s); - printf("%f, %f, %f, %f\n", a, b, c, s); - - a = 10.234, b = 2.; - blas·rotg(&a, &b, &c, &s); - printf("%f, %f, %f, %f\n", a, b, c, s); - return 0; - } -- cgit v1.2.1