aboutsummaryrefslogtreecommitdiff
path: root/sys/libmath/blas.c
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-08 21:33:24 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-08 21:33:24 -0700
commit04c688f125069b65517b00660c31c81e210ddf3a (patch)
tree5a92bd41181ae1c8c586f7da701c4c1115bd5dd5 /sys/libmath/blas.c
parent327ca20a2a89d2408b53ff7854982560304cb76c (diff)
Adding strided computation to blas kernels.
I started implementing LQ factorization and immediately realized I needed strided views. For simplicity, I will just implement them in the most portable, C native way (no vectorization). Speed can come later.
Diffstat (limited to 'sys/libmath/blas.c')
-rw-r--r--sys/libmath/blas.c228
1 files changed, 189 insertions, 39 deletions
diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c
index 227715c..f1ae09d 100644
--- a/sys/libmath/blas.c
+++ b/sys/libmath/blas.c
@@ -248,7 +248,7 @@ scale_kernel8(int n, double *x, double a)
}
void
-blas·scale(int len, double *x, double a)
+blas·scale(int len, double a, double *x)
{
int n;
@@ -459,18 +459,121 @@ dot_kernel8(int len, double *x, double *y)
}
double
-blas·dot(int len, double *x, double *y)
+blas·dot(int len, double *x, int incx, double *y, int incy)
+{
+ int i, n, ix, iy;
+ double res, mul[4], sum[2];
+
+ if (len == 0) return 0;
+
+ if (incx == 1 && incy == 1) {
+ n = len & ~15; // neat trick
+ res = dot_kernel8_fma3(n, x, y);
+
+ for (i = n; i < len; i++) {
+ res += x[i] * y[i];
+ }
+ return res;
+ }
+
+ n = len & ~3;
+ for (i = 0, ix = 0, iy = 0; i < n; i += 4, ix += 4*incx, iy += 4*incy) {
+ mul[0] = x[ix+0*incx] * y[iy+0*incy];
+ mul[1] = x[ix+1*incx] * y[iy+1*incy];
+ mul[2] = x[ix+2*incx] * y[iy+2*incy];
+ mul[3] = x[ix+3*incx] * y[iy+3*incy];
+
+ sum[0] += mul[0] + mul[2];
+ sum[1] += mul[1] + mul[3];
+ }
+
+ for (; i < len; i++, ix += incx, iy += incy) {
+ sum[0] += x[ix] * y[iy];
+ }
+
+ res = sum[0] + sum[1];
+ return res;
+}
+
+/*
+ * euclidean norm
+ * ||x||
+ */
+double
+blas·norm(int len, double *x)
+{
+ double res;
+
+ res = blas·dot(len, x, 1, x, 1);
+ res = math·sqrt(res);
+
+ return res;
+}
+
+static
+double
+sum_kernel8_avx2(int len, double *x)
+{
+ register int i;
+ __m256d sum[2];
+ __m128d res;
+
+ for (i = 0; i < arrlen(sum); i++) {
+ sum[i] = _mm256_setzero_pd();
+ }
+
+ for (i = 0; i < len; i += 8) {
+ sum[0] += _mm256_loadu_pd(x+i+0);
+ sum[1] += _mm256_loadu_pd(x+i+4);
+ }
+
+ sum[0] += sum[1];
+
+ res = _mm_add_pd(_mm256_extractf128_pd(sum[0], 0), _mm256_extractf128_pd(sum[0], 1));
+ res = _mm_hadd_pd(res, res);
+
+ return res[0];
+}
+
+static
+double
+sum_kernel8(int len, double *x, double *y)
+{
+ double res;
+ register int i;
+
+ for (i = 0; i < len; i += 8) {
+ res += x[i] +
+ x[i+1] +
+ x[i+2] +
+ x[i+3] +
+ x[i+4] +
+ x[i+5] +
+ x[i+6] +
+ x[i+7];
+ }
+
+ return res;
+}
+
+
+/*
+ * L1 norm
+ * sum(x_i)
+ */
+double
+blas·sum(int len, double *x)
{
int i, n;
double res;
if (len == 0) return 0;
- n = len & ~15; // neat trick
- res = dot_kernel8_fma3(n, x, y);
+ n = len & ~7;
+ res = sum_kernel8_avx2(n, x);
for (i = n; i < len; i++) {
- res += x[i] * y[i];
+ res += x[i];
}
return res;
@@ -833,7 +936,7 @@ blas·tpmv(blas·Flag f, int n, double *m, double *x)
{
int i;
for (i = 0; i < n; m += (n-i), ++x, ++i) {
- *x = blas·dot(n-i, m, x);
+ *x = blas·dot(n-i, m, 1, x, 1);
}
}
@@ -853,7 +956,7 @@ blas·tpsv(blas·Flag f, int n, double *m, double *x)
x += (n - 1);
m += ((n * (n+1))/2 - 1);
for (i = n-1; i >= 0; --i, --x, m -= (n-i)) {
- r = blas·dot(n-i-1, m+1, x+1);
+ r = blas·dot(n-i-1, m+1, 1, x+1, 1);
*x = (*x - r) / *m;
}
}
@@ -1098,13 +1201,13 @@ blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2)
}
#define NITER 10000
-#define NCOL 5007
-#define NROW 5007
+#define NCOL 5000007
+#define NROW 57
error
test·level3()
{
- int i, n;
+ vlong i, n;
clock_t t;
double *x, *y, *m[3];
@@ -1154,7 +1257,7 @@ test·level3()
blas·gemm(NROW, NROW, NROW, 1.2, m[0], m[1], 2.8, m[2]);
t = clock() - t;
tprof[0] += 1000.*t/CLOCKS_PER_SEC;
- res[0] = blas·dot(NROW*NCOL, m[2], m[2]);
+ res[0] = blas·dot(NROW*NCOL, m[2], 1, m[2], 1);
}
printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER);
printf("--> result (naive): %f\n", res[0]);
@@ -1165,6 +1268,54 @@ test·level3()
}
void
+test·level2()
+{
+ int i, j, n, it;
+ clock_t t;
+
+ double *x, *y, *m;
+ double tprof[2];
+
+ rng·init(0);
+
+ tprof[0] = 0, tprof[1] = 0;
+ x = malloc(sizeof(*x)*NCOL);
+ y = malloc(sizeof(*x)*NCOL);
+ m = malloc(sizeof(*x)*NCOL*(NCOL+1)/2);
+
+ for (it = 0; it < NITER; it++) {
+ n = 0;
+ for (i = 0; i < NCOL; i++) {
+ y[i] = rng·random();
+ for (j = i; j < NCOL; j++) {
+ m[n++] = rng·random() + .1; // To ensure not singular
+ }
+ }
+
+ memcpy(x, y, NCOL * sizeof(*x));
+
+ t = clock();
+ blas·tpsv(0, NCOL, m, x);
+ t = clock() - t;
+ tprof[0] += 1000.*t/CLOCKS_PER_SEC;
+
+ t = clock();
+ cblas_dtpsv(CblasRowMajor, CblasUpper, CblasNoTrans, CblasNonUnit, NCOL, m, y, 1);
+ t = clock() - t;
+ tprof[1] += 1000.*t/CLOCKS_PER_SEC;
+
+ for (i = 0; i < NCOL; i++) {
+ if (math·abs(x[i] - y[i])/math·abs(x[i]) > 1e-5) {
+ errorf("failure at index %d: %f != %f", i, x[i], y[i]);
+ }
+ }
+ }
+
+ printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER);
+ printf("mean time/iteration (oblas): %fms\n", tprof[1]/NITER);
+}
+
+void
print·array(int n, double *x)
{
double *end;
@@ -1236,47 +1387,46 @@ test·level1()
error
main()
{
- int i, j, n, it;
+ int i, n;
+ double *x, *y;
+ double res[2], tprof[2];
clock_t t;
- double *x, *y, *m;
- double tprof[2];
-
- rng·init(0);
-
- tprof[0] = 0, tprof[1] = 0;
x = malloc(sizeof(*x)*NCOL);
y = malloc(sizeof(*x)*NCOL);
- m = malloc(sizeof(*x)*NCOL*(NCOL+1)/2);
+ rng·init(0);
- for (it = 0; it < NITER; it++) {
- n = 0;
+ for (n = 0; n < NITER; n++) {
for (i = 0; i < NCOL; i++) {
+ x[i] = rng·random();
y[i] = rng·random();
- for (j = i; j < NCOL; j++) {
- m[n++] = rng·random() + .1; // To ensure not singular
- }
}
- memcpy(x, y, NCOL * sizeof(*x));
-
- t = clock();
- blas·tpsv(0, NCOL, m, x);
- t = clock() - t;
- tprof[0] += 1000.*t/CLOCKS_PER_SEC;
-
- t = clock();
- cblas_dtpsv(CblasRowMajor, CblasUpper, CblasNoTrans, CblasNonUnit, NCOL, m, y, 1);
- t = clock() - t;
+ t = clock();
+ res[1] += cblas_ddot(NCOL/4, x, 4, y, 4);
+ t = clock() - t;
tprof[1] += 1000.*t/CLOCKS_PER_SEC;
- for (i = 0; i < NCOL; i++) {
- if (math·abs(x[i] - y[i])/math·abs(x[i]) > 1e-5) {
- errorf("failure at index %d: %f != %f", i, x[i], y[i]);
- }
- }
+ t = clock();
+ res[0] += blas·dot(NCOL/4, x, 4, y, 4);
+ t = clock() - t;
+ tprof[0] += 1000.*t/CLOCKS_PER_SEC;
}
+ printf("%f, %f\n", res[0], res[1]);
printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER);
printf("mean time/iteration (oblas): %fms\n", tprof[1]/NITER);
+
+ double a, b, c, s;
+
+ a = 10.234, b = 2.;
+ cblas_drotg(&a, &b, &c, &s);
+ printf("%f, %f, %f, %f\n", a, b, c, s);
+
+ a = 10.234, b = 2.;
+ blas·rotg(&a, &b, &c, &s);
+ printf("%f, %f, %f, %f\n", a, b, c, s);
+
+ return 0;
+
}