aboutsummaryrefslogtreecommitdiff
path: root/sys
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-09 13:22:51 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-09 13:22:51 -0700
commit672079795d607270638103dd93fa453645e5a38a (patch)
tree26b6ef732a91b29098b7bccaa7f1289af081f9ad /sys
parent3290956dc738abf4998adb94a96a3eff487fae3c (diff)
feat: all level 1 functions are now strided
Diffstat (limited to 'sys')
-rw-r--r--sys/libmath/blas.c310
-rw-r--r--sys/libmath/linalg.c14
2 files changed, 225 insertions, 99 deletions
diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c
index c419c6f..1e1e1c7 100644
--- a/sys/libmath/blas.c
+++ b/sys/libmath/blas.c
@@ -6,6 +6,8 @@
#include <x86intrin.h>
#include <time.h>
+#define EVEN_BY(x, n) (x) & ~((n)-1)
+
// -----------------------------------------------------------------------
// level one
@@ -219,10 +221,10 @@ blas·rotm(int len, double *x, int incx, double *y, int incy, double p[5])
} else {
n = EVEN_BY(len, 4);
for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) {
- tmp = x[0*incx], x[0*incx] = H[0]*x[0*incx] + H[1]*y[0*incy], y[0*incy] = H[2]*y[0*incy] + H[3]*tmp;
- tmp = x[1*incx], x[1*incx] = H[0]*x[1*incx] + H[1]*y[1*incy], y[1*incy] = H[2]*y[1*incy] + H[3]*tmp;
- tmp = x[2*incx], x[2*incx] = H[0]*x[2*incx] + H[1]*y[2*incy], y[2*incy] = H[2]*y[2*incy] + H[3]*tmp;
- tmp = x[3*incx], x[3*incx] = H[0]*x[3*incx] + H[1]*y[3*incy], y[3*incy] = H[2]*y[3*incy] + H[3]*tmp;
+ tmp = x[0*incx], x[0*incx] = H[0]*x[0*incx] + H[1]*y[0*incy], y[0*incy] = H[2]*tmp + H[3]*y[0*incy];
+ tmp = x[1*incx], x[1*incx] = H[0]*x[1*incx] + H[1]*y[1*incy], y[1*incy] = H[2]*tmp + H[3]*y[1*incy];
+ tmp = x[2*incx], x[2*incx] = H[0]*x[2*incx] + H[1]*y[2*incy], y[2*incy] = H[2]*tmp + H[3]*y[2*incy];
+ tmp = x[3*incx], x[3*incx] = H[0]*x[3*incx] + H[1]*y[3*incy], y[3*incy] = H[2]*tmp + H[3]*y[3*incy];
}
}
@@ -273,15 +275,25 @@ scale_kernel8(int n, double *x, double a)
}
void
-blas·scale(int len, double a, double *x)
+blas·scale(int len, double a, double *x, int inc)
{
- int n;
-
- n = len & ~7;
- scale_kernel8_avx2(n, x, a);
+ int n, ix;
- for (; n < len; n++) {
- x[n] *= a;
+ if (inc == 1) {
+ n = EVEN_BY(len, 8);
+ scale_kernel8_avx2(n, x, a);
+ ix = n;
+ } else {
+ n = EVEN_BY(len, 4);
+ for (ix = 0; ix < n*inc; ix += 4*inc) {
+ x[ix+0*inc] *= a;
+ x[ix+1*inc] *= a;
+ x[ix+2*inc] *= a;
+ x[ix+3*inc] *= a;
+ }
+ }
+ for (; n < len; n++, ix += inc) {
+ x[ix] *= a;
}
}
@@ -291,9 +303,25 @@ blas·scale(int len, double a, double *x)
*/
void
-blas·copy(int len, double *x, double *y)
+blas·copy(int len, double *x, int incx, double *y, int incy)
{
- memcpy(y, x, sizeof(*x) * len);
+ int n, i, ix, iy;
+ if (incx == 1 && incy == 1) {
+ memcpy(y, x, sizeof(*x) * len);
+ return;
+ }
+
+ n = EVEN_BY(len, 4);
+ for (i = 0, incx = 0, incy = 0; i < n; i+=4, ix+=4*incx, iy+=4*incy) {
+ y[iy+0*incy] = x[ix+0*incx];
+ y[iy+1*incy] = x[ix+1*incx];
+ y[iy+2*incy] = x[ix+2*incx];
+ y[iy+3*incy] = x[ix+3*incx];
+ }
+
+ for (; n < len; n++, ix+=incx, iy+=incy) {
+ y[iy] = x[ix];
+ }
}
/*
@@ -339,15 +367,28 @@ swap_kernel8(int n, double *x, double *y)
}
void
-blas·swap(int len, double *x, double *y)
+blas·swap(int len, double *x, int incx, double *y, int incy)
{
- int n;
+ int n, i, ix, iy;
double tmp;
- n = len & ~7;
- swap_kernel8(n, x, y);
- for (; n < len; n++) {
- tmp = x[n], x[n] = y[n], y[n] = tmp;
+ if (incx == 1 && incy == 1) {
+ n = EVEN_BY(len, 8);
+ swap_kernel8(n, x, y);
+ ix = n;
+ iy = n;
+ } else {
+ n = EVEN_BY(len, 4);
+ for (i = 0, ix = 0, iy = 0; i < n; i += 4, ix += 4*incx, iy += 4*incy) {
+ tmp = x[ix + 0*incx], x[ix + 0*incx] = y[iy + 0*incy], y[iy + 0*incy] = tmp;
+ tmp = x[ix + 1*incx], x[ix + 1*incx] = y[iy + 1*incy], y[iy + 1*incy] = tmp;
+ tmp = x[ix + 2*incx], x[ix + 2*incx] = y[iy + 2*incy], y[iy + 2*incy] = tmp;
+ tmp = x[ix + 3*incx], x[ix + 3*incx] = y[iy + 3*incy], y[iy + 3*incy] = tmp;
+ }
+ }
+
+ for (; n < len; n++, ix += incx, iy += incy) {
+ tmp = x[ix], x[ix] = y[iy], y[iy] = tmp;
}
}
@@ -391,15 +432,27 @@ axpy_kernel8(int n, double a, double *x, double *y)
}
void
-blas·axpy(int len, double a, double *x, double *y)
+blas·axpy(int len, double a, double *x, int incx, double *y, int incy)
{
- int n;
+ int n, i;
- n = len & ~7;
- axpy_kernel8_avx2(n, a, x, y);
+ if (incx == 1 && incy == 1) {
+ n = EVEN_BY(len, 8);
+ axpy_kernel8_avx2(n, a, x, y);
+ x += n;
+ y += n;
+ } else {
+ n = EVEN_BY(len, 4);
+ for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) {
+ y[0*incy] += a*x[0*incx];
+ y[1*incy] += a*x[1*incx];
+ y[2*incy] += a*x[2*incx];
+ y[3*incy] += a*x[3*incx];
+ }
+ }
- for (; n < len; n++) {
- y[n] += a*x[n];
+ for (; n < len; n++, x+=incx, y+=incy) {
+ *y += a*(*x);
}
}
@@ -486,38 +539,35 @@ dot_kernel8(int len, double *x, double *y)
double
blas·dot(int len, double *x, int incx, double *y, int incy)
{
- int i, n, ix, iy;
- double res, mul[4], sum[2];
-
+ int i, n;
+ double mul[4], sum[2];
if (len == 0) return 0;
+ sum[0] = 0, sum[1] = 0;
if (incx == 1 && incy == 1) {
- n = len & ~15; // neat trick
- res = dot_kernel8_fma3(n, x, y);
-
- for (i = n; i < len; i++) {
- res += x[i] * y[i];
+ n = EVEN_BY(len, 16);
+ sum[0] = dot_kernel8_fma3(n, x, y);
+ x += n;
+ y += n;
+ } else {
+ n = EVEN_BY(len, 4);
+ for (i = 0; i < n; i += 4, x += 4*incx, y += 4*incy) {
+ mul[0] = x[0*incx] * y[0*incy];
+ mul[1] = x[1*incx] * y[1*incy];
+ mul[2] = x[2*incx] * y[2*incy];
+ mul[3] = x[3*incx] * y[3*incy];
+
+ sum[0] += mul[0] + mul[2];
+ sum[1] += mul[1] + mul[3];
}
- return res;
- }
-
- n = len & ~3;
- for (i = 0, ix = 0, iy = 0; i < n; i += 4, ix += 4*incx, iy += 4*incy) {
- mul[0] = x[ix+0*incx] * y[iy+0*incy];
- mul[1] = x[ix+1*incx] * y[iy+1*incy];
- mul[2] = x[ix+2*incx] * y[iy+2*incy];
- mul[3] = x[ix+3*incx] * y[iy+3*incy];
-
- sum[0] += mul[0] + mul[2];
- sum[1] += mul[1] + mul[3];
}
- for (; i < len; i++, ix += incx, iy += incy) {
- sum[0] += x[ix] * y[iy];
+ for (; i < len; i++, x += incx, y += incy) {
+ sum[0] += x[0] * y[0];
}
- res = sum[0] + sum[1];
- return res;
+ sum[0] += sum[1];
+ return sum[0];
}
/*
@@ -525,11 +575,11 @@ blas·dot(int len, double *x, int incx, double *y, int incy)
* ||x||
*/
double
-blas·norm(int len, double *x)
+blas·norm(int len, double *x, int incx)
{
double res;
- res = blas·dot(len, x, 1, x, 1);
+ res = blas·dot(len, x, incx, x, incx);
res = math·sqrt(res);
return res;
@@ -587,18 +637,28 @@ sum_kernel8(int len, double *x, double *y)
* sum(x_i)
*/
double
-blas·sum(int len, double *x)
+blas·sum(int len, double *x, int inc)
{
int i, n;
double res;
if (len == 0) return 0;
- n = len & ~7;
- res = sum_kernel8_avx2(n, x);
+ if (inc == 1) {
+ n = EVEN_BY(len, 8);
+ res = sum_kernel8_avx2(n, x);
+ } else {
+ n = EVEN_BY(len, 4);
+ for (i = 0; i < n; i++, x += 4*inc) {
+ res += x[0*inc];
+ res += x[1*inc];
+ res += x[2*inc];
+ res += x[3*inc];
+ }
+ }
- for (i = n; i < len; i++) {
- res += x[i];
+ for (i = n; i < len; i++, x += inc) {
+ res += x[0];
}
return res;
@@ -754,23 +814,52 @@ argmax_kernel8(int n, double *x)
}
int
-blas·argmax(int len, double *x)
+blas·argmax(int len, double *x, int inc)
{
- int i, n;
+ int i, ix, n;
double max;
- i = len & ~7;
- n = argmax_kernel8_avx2(i, x);
+ if (len == 0) {
+ return -1;
+ }
- max = x[n];
- for (; i < len; i++) {
- if (x[i] > max) {
- n = i;
- max = x[i];
+ if (inc == 1) {
+ n = EVEN_BY(len, 8);
+ ix = argmax_kernel8_avx2(n, x);
+ max = x[ix];
+ x += n;
+ } else {
+ n = EVEN_BY(len, 4);
+ ix = 0;
+ max = x[ix];
+ for (i = 0; i < n; i += 4, x += 4*inc) {
+ if (x[0*inc] > max) {
+ ix = i;
+ max = x[0*inc];
+ }
+ if (x[1*inc] > max) {
+ ix = i+1;
+ max = x[1*inc];
+ }
+ if (x[2*inc] > max) {
+ ix = i+2;
+ max = x[2*inc];
+ }
+ if (x[3*inc] > max) {
+ ix = i+3;
+ max = x[3*inc];
+ }
}
}
- return n;
+ for (; n < len; n++, x += inc) {
+ if (*x > max) {
+ ix = n;
+ max = *x;
+ }
+ }
+
+ return ix;
}
int
@@ -919,23 +1008,52 @@ argmin_kernel8(int n, double *x)
}
int
-blas·argmin(int len, double *x)
+blas·argmin(int len, double *x, int inc)
{
- int i, n;
double min;
+ int i, ix, n;
- i = len & ~7;
- n = argmin_kernel8_avx2(i, x);
+ if (len == 0) {
+ return -1;
+ }
- min = x[n];
- for (; i < len; i++) {
- if (x[i] < min) {
- n = i;
- min = x[i];
+ if (inc == 1) {
+ n = EVEN_BY(len, 8);
+ ix = argmin_kernel8_avx2(n, x);
+ min = x[ix];
+ x += n;
+ } else {
+ n = EVEN_BY(len, 4);
+ ix = 0;
+ min = x[ix];
+ for (i = 0; i < n; i += 4, x += 4*inc) {
+ if (x[0*inc] < min) {
+ ix = i;
+ min = x[0*inc];
+ }
+ if (x[1*inc] < min) {
+ ix = i+1;
+ min = x[1*inc];
+ }
+ if (x[2*inc] < min) {
+ ix = i+2;
+ min = x[2*inc];
+ }
+ if (x[3*inc] < min) {
+ ix = i+3;
+ min = x[3*inc];
+ }
}
}
+ for (; n < len; n++, x += inc) {
+ if (*x < min) {
+ ix = n;
+ min = *x;
+ }
- return n;
+ }
+
+ return ix;
}
// -----------------------------------------------------------------------
@@ -1225,7 +1343,7 @@ blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2)
{
}
-#define NITER 10000
+#define NITER 3000
#define NCOL 5000007
#define NROW 57
@@ -1374,7 +1492,7 @@ test·level1()
memcpy(x, y, sizeof(*x)*NCOL);
t = clock();
- ai = blas·argmin(NCOL, x);
+ ai = blas·argmin(NCOL, x, 1);
t = clock() - t;
tprof[0] += 1000.*t/CLOCKS_PER_SEC;
@@ -1409,49 +1527,57 @@ test·level1()
return 0;
}
+#define STEP 1
error
main()
{
int i, n;
- double *x, *y;
+ double *x, *y, *w, *z;
double res[2], tprof[2];
+ int idx[2];
clock_t t;
x = malloc(sizeof(*x)*NCOL);
y = malloc(sizeof(*x)*NCOL);
+ w = malloc(sizeof(*x)*NCOL);
+ z = malloc(sizeof(*x)*NCOL);
rng·init(0);
for (n = 0; n < NITER; n++) {
for (i = 0; i < NCOL; i++) {
x[i] = rng·random();
y[i] = rng·random();
+
+ w[i] = x[i];
+ z[i] = y[i];
}
t = clock();
- res[1] += cblas_ddot(NCOL/4, x, 4, y, 4);
+ idx[0] = cblas_idamin(NCOL/STEP, w, STEP);
t = clock() - t;
tprof[1] += 1000.*t/CLOCKS_PER_SEC;
t = clock();
- res[0] += blas·dot(NCOL/4, x, 4, y, 4);
+ idx[1] = blas·argmin(NCOL/STEP, x, STEP);
t = clock() - t;
tprof[0] += 1000.*t/CLOCKS_PER_SEC;
+
+ if (idx[0] != idx[1]) {
+ errorf("%d != %d", idx[0], idx[1]);
+ }
+ // if (math·abs(res[0] - res[1])/math·abs(res[0]) > 1e-4) {
+ // errorf("%f != %f", res[0], res[1]);
+ // }
+ // for (i = 0; i < NCOL; i++) {
+ // if (math·abs(x[i] - w[i]) + math·abs(y[i] - z[i]) > 1e-4) {
+ // errorf("%f != %f & %f != %f at index %d", x[i], w[i], y[i], z[i], i);
+ // }
+ // }
}
printf("%f, %f\n", res[0], res[1]);
printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER);
printf("mean time/iteration (oblas): %fms\n", tprof[1]/NITER);
- double a, b, c, s;
-
- a = 10.234, b = 2.;
- cblas_drotg(&a, &b, &c, &s);
- printf("%f, %f, %f, %f\n", a, b, c, s);
-
- a = 10.234, b = 2.;
- blas·rotg(&a, &b, &c, &s);
- printf("%f, %f, %f, %f\n", a, b, c, s);
-
return 0;
-
}
diff --git a/sys/libmath/linalg.c b/sys/libmath/linalg.c
index 5a73527..09e100c 100644
--- a/sys/libmath/linalg.c
+++ b/sys/libmath/linalg.c
@@ -10,8 +10,8 @@ linalg·normalize(math·Vector vec)
{
double norm;
- norm = blas·norm(vec.len, vec.data);
- blas·scale(vec.len, 1/norm, vec.data);
+ norm = blas·norm(vec.len, vec.data, 1);
+ blas·scale(vec.len, 1/norm, vec.data, 1);
}
// TODO: Write blas wrappers that eat vectors for convenience
@@ -50,12 +50,12 @@ linalg·lq(math·Matrix m, math·Vector w)
len = m.dim[0] - i;
// TODO: Don't want to compute norm twice!!
- w.data[0] = math·sgn(row[0]) * blas·norm(len, row);
- blas·axpy(len, 1.0, row, w.data);
- mag = blas·norm(len, w.data);
- blas·scale(len, 1/mag, w.data);
+ w.data[0] = math·sgn(row[0]) * blas·norm(len, row, 1);
+ blas·axpy(len, 1.0, row, 1, w.data, 1);
+ mag = blas·norm(len, w.data, 1);
+ blas·scale(len, 1/mag, w.data, 1);
- blas·copy(len - m.dim[0], w.data, m.data + i);
+ blas·copy(len - m.dim[0], w.data, 1, m.data + i, 1);
}
return err·nil;