From 327ca20a2a89d2408b53ff7854982560304cb76c Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Fri, 8 May 2020 16:45:52 -0700 Subject: added more level 2 and 3 functions to blas implementation --- include/libmath.h | 65 ++++++++- sys/libmath/blas.c | 406 +++++++++++++++++++++++++++++++++++++++++++++------ sys/libmath/linalg.c | 5 + 3 files changed, 429 insertions(+), 47 deletions(-) create mode 100644 sys/libmath/linalg.c diff --git a/include/libmath.h b/include/libmath.h index ec15f6f..b0ae434 100644 --- a/include/libmath.h +++ b/include/libmath.h @@ -130,4 +130,67 @@ double math·trunc(double); float math·truncf(float); // ----------------------------------------------------------------------- -// linear algebra +// basic linear algebra compute kernels + +// TODO: think of better names +enum +{ + blas·LowerTri = 1u, + blas·Transpose = 2u, + blas·ConjTranspose = 4u, + blas·DiagOnes = 8u, + blas·LeftSide = 16u, +}; + +typedef uint32 blas·Flag; + +/* level 1 */ +void blas·rot(int len, double *x, double *y, double cos, double sin); +void blas·rotg(double *a, double *b, double *cos, double *sin); +error blas·rotm(int len, double *x, double *y, double p[5]); +void blas·scale(int len, double *x, double a); +void blas·copy(int len, double *x, double *y); +void blas·swap(int len, double *x, double *y); +void blas·axpy(int len, double a, double *x, double *y); +double blas·dot(int len, double *x, double *y); +int blas·argmax(int len, double *x); +int blas·argmin(int len, double *x); + +/* level 2 */ +void blas·tpmv(blas·Flag f, int n, double *m, double *x); +void blas·tpsv(blas·Flag f, int n, double *m, double *x); +error blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y) ; +void blas·ger(int nrow, int ncol, double a, double *x, double *y, double *m); +void blas·her(int n, double a, double *x, double *m); +void blas·syr(int nrow, int ncol, double a, double *x, double *m); + +/* level 3 */ +void blas·gemm(int n1, int n2, int n3, double a, double *m1, double *m2, double b, double *m3); +void blas·trmm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2); +void blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2); + +// ----------------------------------------------------------------------- +// higher level linear algebra + +struct linalg·Header +{ + void *h; + mem·Allocator heap; + + double *data; +}; + +typedef struct math·Vector +{ + int len; + struct linalg·Header; +} math·Vector; + +typedef struct math·Matrix +{ + int dim[2]; + blas·Flag kind; + struct linalg·Header; +} math·Matrix; + +// TODO: tensor ala numpy diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c index f12e3e2..227715c 100644 --- a/sys/libmath/blas.c +++ b/sys/libmath/blas.c @@ -276,6 +276,7 @@ blas·copy(int len, double *x, double *y) * y <=> x */ +static void swap_kernel8_avx2(int n, double *x, double *y) { @@ -294,6 +295,7 @@ swap_kernel8_avx2(int n, double *x, double *y) } } +static void swap_kernel8(int n, double *x, double *y) { @@ -462,6 +464,8 @@ blas·dot(int len, double *x, double *y) int i, n; double res; + if (len == 0) return 0; + n = len & ~15; // neat trick res = dot_kernel8_fma3(n, x, y); @@ -483,7 +487,7 @@ argmax_kernel8_avx2(int n, double *x) register int i, msk, maxidx[4]; __m256d val, cmp, max; - maxidx[0] = 0, maxidx[1] = 1, maxidx[2] = 2, maxidx[3] = 3; + maxidx[0] = 0, maxidx[1] = 0, maxidx[2] = 0, maxidx[3] = 0; max = _mm256_loadu_pd(x); for (i = 0; i < n; i += 4) { @@ -493,102 +497,102 @@ argmax_kernel8_avx2(int n, double *x) switch (msk) { case 1: max = _mm256_blend_pd(max, val, 1); - maxidx[0] = i+0; + maxidx[0] = i; break; case 2: max = _mm256_blend_pd(max, val, 2); - maxidx[1] = i+1; + maxidx[1] = i; break; case 3: max = _mm256_blend_pd(max, val, 3); - maxidx[0] = i+0; - maxidx[1] = i+1; + maxidx[0] = i; + maxidx[1] = i; break; case 4: max = _mm256_blend_pd(max, val, 4); - maxidx[2] = i+2; + maxidx[2] = i; break; case 5: max = _mm256_blend_pd(max, val, 5); - maxidx[2] = i+2; - maxidx[0] = i+0; + maxidx[2] = i; + maxidx[0] = i; break; case 6: max = _mm256_blend_pd(max, val, 6); - maxidx[2] = i+2; - maxidx[1] = i+1; + maxidx[2] = i; + maxidx[1] = i; break; case 7: max = _mm256_blend_pd(max, val, 7); - maxidx[2] = i+2; - maxidx[1] = i+1; - maxidx[0] = i+0; + maxidx[2] = i; + maxidx[1] = i; + maxidx[0] = i; break; case 8: max = _mm256_blend_pd(max, val, 8); - maxidx[3] = i+3; + maxidx[3] = i; break; case 9: max = _mm256_blend_pd(max, val, 9); - maxidx[3] = i+3; - maxidx[0] = i+0; + maxidx[3] = i; + maxidx[0] = i; break; case 10: max = _mm256_blend_pd(max, val, 10); - maxidx[3] = i+3; - maxidx[1] = i+1; + maxidx[3] = i; + maxidx[1] = i; break; case 11: max = _mm256_blend_pd(max, val, 11); - maxidx[3] = i+3; - maxidx[1] = i+1; - maxidx[0] = i+0; + maxidx[3] = i; + maxidx[1] = i; + maxidx[0] = i; break; case 12: max = _mm256_blend_pd(max, val, 12); - maxidx[3] = i+3; - maxidx[2] = i+2; + maxidx[3] = i; + maxidx[2] = i; break; case 13: max = _mm256_blend_pd(max, val, 13); - maxidx[3] = i+3; - maxidx[2] = i+2; - maxidx[0] = i+0; + maxidx[3] = i; + maxidx[2] = i; + maxidx[0] = i; break; case 14: max = _mm256_blend_pd(max, val, 14); - maxidx[3] = i+3; - maxidx[2] = i+2; - maxidx[1] = i+1; + maxidx[3] = i; + maxidx[2] = i; + maxidx[1] = i; break; case 15: max = _mm256_blend_pd(max, val, 15); - maxidx[3] = i+3; - maxidx[2] = i+2; - maxidx[1] = i+1; - maxidx[0] = i+0; + maxidx[3] = i; + maxidx[2] = i; + maxidx[1] = i; + maxidx[0] = i; break; case 0: default: ; } } - maxidx[0] = (x[maxidx[0]] > x[maxidx[1]]) ? maxidx[0] : maxidx[1]; - maxidx[1] = (x[maxidx[2]] > x[maxidx[3]]) ? maxidx[2] : maxidx[3]; + maxidx[0] = (x[maxidx[0]+0] > x[maxidx[1]+1]) ? maxidx[0]+0 : maxidx[1]+1; + maxidx[1] = (x[maxidx[2]+2] > x[maxidx[3]+3]) ? maxidx[2]+2 : maxidx[3]+3; maxidx[0] = (x[maxidx[0]] > x[maxidx[1]]) ? maxidx[0] : maxidx[1]; return maxidx[0]; @@ -597,7 +601,7 @@ argmax_kernel8_avx2(int n, double *x) int argmax_kernel8(int n, double *x) { -#define SET(d) idx[d] = 0, max[d] = x[d] +#define SET(d) idx[d] = d, max[d] = x[d] #define PUT(d) if (x[i+d] > max[d]) idx[d] = i+d, max[d] = x[i+d] int i, idx[8]; double max[8]; @@ -641,6 +645,171 @@ blas·argmax(int len, double *x) return n; } +int +argmin_kernel8_avx2(int n, double *x) +{ + register int i, msk, minidx[4]; + __m256d val, cmp, min; + + minidx[0] = 0, minidx[1] = 0, minidx[2] = 0, minidx[3] = 0; + min = _mm256_loadu_pd(x); + + for (i = 0; i < n; i += 4) { + val = _mm256_loadu_pd(x+i); + cmp = _mm256_cmp_pd(val, min, _CMP_LT_OS); + msk = _mm256_movemask_pd(cmp); + switch (msk) { + case 1: + min = _mm256_blend_pd(min, val, 1); + minidx[0] = i; + break; + + case 2: + min = _mm256_blend_pd(min, val, 2); + minidx[1] = i; + break; + + case 3: + min = _mm256_blend_pd(min, val, 3); + minidx[0] = i; + minidx[1] = i; + break; + + case 4: + min = _mm256_blend_pd(min, val, 4); + minidx[2] = i; + break; + + case 5: + min = _mm256_blend_pd(min, val, 5); + minidx[2] = i; + minidx[0] = i; + break; + + case 6: + min = _mm256_blend_pd(min, val, 6); + minidx[2] = i; + minidx[1] = i; + break; + + case 7: + min = _mm256_blend_pd(min, val, 7); + minidx[2] = i; + minidx[1] = i; + minidx[0] = i; + break; + + case 8: + min = _mm256_blend_pd(min, val, 8); + minidx[3] = i; + break; + + case 9: + min = _mm256_blend_pd(min, val, 9); + minidx[3] = i; + minidx[0] = i; + break; + + case 10: + min = _mm256_blend_pd(min, val, 10); + minidx[3] = i; + minidx[1] = i; + break; + + case 11: + min = _mm256_blend_pd(min, val, 11); + minidx[3] = i; + minidx[1] = i; + minidx[0] = i; + break; + + case 12: + min = _mm256_blend_pd(min, val, 12); + minidx[3] = i; + minidx[2] = i; + break; + + case 13: + min = _mm256_blend_pd(min, val, 13); + minidx[3] = i; + minidx[2] = i; + minidx[0] = i; + break; + + case 14: + min = _mm256_blend_pd(min, val, 14); + minidx[3] = i; + minidx[2] = i; + minidx[1] = i; + break; + + case 15: + min = _mm256_blend_pd(min, val, 15); + minidx[3] = i; + minidx[2] = i; + minidx[1] = i; + minidx[0] = i; + break; + + case 0: + default: ; + } + } + minidx[0] = (x[minidx[0]+0] < x[minidx[1]+1]) ? minidx[0]+0 : minidx[1]+1; + minidx[1] = (x[minidx[2]+2] < x[minidx[3]+3]) ? minidx[2]+2 : minidx[3]+3; + minidx[0] = (x[minidx[0]] < x[minidx[1]]) ? minidx[0] : minidx[1]; + + return minidx[0]; +} + + +int +argmin_kernel8(int n, double *x) +{ +#define SET(d) idx[d] = d, min[d] = x[d] +#define PUT(d) if (x[i+d] < min[d]) idx[d] = i+d, min[d] = x[i+d] + int i, idx[8]; + double min[8]; + + SET(0); SET(1); SET(2); SET(3); + SET(4); SET(5); SET(6); SET(7); + + for (i = 0; i < n; i += 8) { + PUT(0); PUT(1); PUT(2); PUT(3); + PUT(4); PUT(5); PUT(6); PUT(7); + } + + n = 0; + for (i = 1; i < 8; i++) { + if (min[i] < min[n]) { + n = i; + } + } + return idx[n]; +#undef PUT +#undef SET +} + +int +blas·argmin(int len, double *x) +{ + int i, n; + double min; + + i = len & ~7; + n = argmin_kernel8_avx2(i, x); + + min = x[n]; + for (; i < len; i++) { + if (x[i] < min) { + n = i; + min = x[i]; + } + } + + return n; +} + // ----------------------------------------------------------------------- // level two @@ -650,8 +819,47 @@ blas·argmax(int len, double *x) */ +// NOTE: All triangular matrix methods are assumed packed and upper for now! + /* - * Affine transformation + * triangular shaped transformation + * x = Mx + * @M: square triangular + * TODO(PERF): Diagnose speed issues + * TODO: Finish all other flag cases! + */ +void +blas·tpmv(blas·Flag f, int n, double *m, double *x) +{ + int i; + for (i = 0; i < n; m += (n-i), ++x, ++i) { + *x = blas·dot(n-i, m, x); + } +} + +/* + * solve triangular set of equations + * x = M^{-1}b + * @M: square triangular + * TODO(PERF): Diagnose speed issues + * TODO: Finish all other flag cases! + */ +void +blas·tpsv(blas·Flag f, int n, double *m, double *x) +{ + int i; + double r; + + x += (n - 1); + m += ((n * (n+1))/2 - 1); + for (i = n-1; i >= 0; --i, --x, m -= (n-i)) { + r = blas·dot(n-i-1, m+1, x+1); + *x = (*x - r) / *m; + } +} + +/* + * general affine transformation * y = aMx + by */ static @@ -776,6 +984,17 @@ blas·ger(int nrow, int ncol, double a, double *x, double *y, double *m) } } +/* + * rank one addition + * M = ax(x^T) + M + */ + +void +blas·her(int n, double a, double *x, double *m) +{ + blas·ger(n, n, a, x, x, m); +} + /* * symmetric rank one addition * M = ax(x^T) + M @@ -794,6 +1013,7 @@ blas·syr(int nrow, int ncol, double a, double *x, double *m) } } + // ----------------------------------------------------------------------- // level three @@ -836,12 +1056,53 @@ blas·gemm(int n1, int n2, int n3, double a, double *m1, double *m2, double b, d } } -#define NITER 30000 -#define NCOL 50007 -#define NROW 107 +/* + * triangular matrix multiplication + * m2 = a * m1 * m2 _OR_ a * m2 * m1 + * m1 is assumed triangular + * einstein notation: + * m2_{ij} = a m1_{ik} m2_{kj} _OR_ a m1_{kj} m2_{ik} + * + * nrow = # rows of m2 + * ncol = # cols of m2 + * TODO(PERF): make compute kernel + * TODO: finish all other flags + */ +void +blas·trmm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2) +{ + int i, j, k, len; + + for (i = 0; i < nrow; i++) { + for (j = 0; j < ncol; j++) { + for (k = i; k < ncol; k++) { + m2[i + ncol*j] += a * m1[i + nrow*k] * m2[k + ncol*j]; + } + } + } +} + +/* + * solve triangular matrix system of equations + * m2 = a * m1^{-1L} _OR_ a * m2 * m1 + * m1 is assumed triangular + * + * nrow = # rows of m2 + * ncol = # cols of m2 + * TODO: complete stub + * TODO: finish all other flags + */ +void +blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2) +{ +} + +#define NITER 10000 +#define NCOL 5007 +#define NROW 5007 error -test·gemm() +test·level3() { int i, n; clock_t t; @@ -915,7 +1176,7 @@ print·array(int n, double *x) } error -main() +test·level1() { int ai, ai2, i, n; double *x, *y; @@ -932,24 +1193,27 @@ main() for (n = 0; n < NITER; n++) { for (i = 0; i < NCOL; i++) { - y[n] = rng·random(); + y[i] = rng·random(); } memcpy(x, y, sizeof(*x)*NCOL); t = clock(); - ai = blas·argmax(NCOL, x); + ai = blas·argmin(NCOL, x); t = clock() - t; tprof[0] += 1000.*t/CLOCKS_PER_SEC; + if (n == 20729) { + printf("[%d]=%f vs [%d]=%f\n", 74202, x[74202], 3, x[3]); + } memcpy(x, y, sizeof(*x)*NCOL); t = clock(); - ai2 = cblas_idamax(NCOL, x, 1); + ai2 = cblas_idamin(NCOL, x, 1); t = clock() - t; tprof[1] += 1000.*t/CLOCKS_PER_SEC; if (ai != ai2) { - printf("iteration %d: %d not equal to %d\n", n, ai, ai2); + printf("iteration %d: %d not equal to %d. %f vs %f\n", n, ai, ai2, x[ai], x[ai2]); } } @@ -965,4 +1229,54 @@ main() a = 10.234, b = 2.; blas·rotg(&a, &b, &c, &s); printf("%f, %f, %f, %f\n", a, b, c, s); + + return 0; +} + +error +main() +{ + int i, j, n, it; + clock_t t; + + double *x, *y, *m; + double tprof[2]; + + rng·init(0); + + tprof[0] = 0, tprof[1] = 0; + x = malloc(sizeof(*x)*NCOL); + y = malloc(sizeof(*x)*NCOL); + m = malloc(sizeof(*x)*NCOL*(NCOL+1)/2); + + for (it = 0; it < NITER; it++) { + n = 0; + for (i = 0; i < NCOL; i++) { + y[i] = rng·random(); + for (j = i; j < NCOL; j++) { + m[n++] = rng·random() + .1; // To ensure not singular + } + } + + memcpy(x, y, NCOL * sizeof(*x)); + + t = clock(); + blas·tpsv(0, NCOL, m, x); + t = clock() - t; + tprof[0] += 1000.*t/CLOCKS_PER_SEC; + + t = clock(); + cblas_dtpsv(CblasRowMajor, CblasUpper, CblasNoTrans, CblasNonUnit, NCOL, m, y, 1); + t = clock() - t; + tprof[1] += 1000.*t/CLOCKS_PER_SEC; + + for (i = 0; i < NCOL; i++) { + if (math·abs(x[i] - y[i])/math·abs(x[i]) > 1e-5) { + errorf("failure at index %d: %f != %f", i, x[i], y[i]); + } + } + } + + printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER); + printf("mean time/iteration (oblas): %fms\n", tprof[1]/NITER); } diff --git a/sys/libmath/linalg.c b/sys/libmath/linalg.c new file mode 100644 index 0000000..57f799b --- /dev/null +++ b/sys/libmath/linalg.c @@ -0,0 +1,5 @@ +#include +#include +#include + + -- cgit v1.2.1