aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--include/libmath.h2
-rw-r--r--sys/libmath/blas.c128
2 files changed, 83 insertions, 47 deletions
diff --git a/include/libmath.h b/include/libmath.h
index ecce28e..b148065 100644
--- a/include/libmath.h
+++ b/include/libmath.h
@@ -160,8 +160,8 @@ int blas·argmin(int len, double *x, int inc);
/* level 2 */
void blas·tpmv(blas·Flag f, int n, double *m, double *x);
-void blas·tpsv(blas·Flag f, int n, double *m, double *x);
error blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double *y) ;
+void blas·tpsv(blas·Flag f, int n, double *m, double *x);
void blas·ger(int nrow, int ncol, double a, double *x, double *y, double *m);
void blas·her(int n, double a, double *x, double *m);
void blas·syr(int nrow, int ncol, double a, double *x, double *m);
diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c
index 1e1e1c7..7056c08 100644
--- a/sys/libmath/blas.c
+++ b/sys/libmath/blas.c
@@ -1110,7 +1110,7 @@ blas·tpsv(blas·Flag f, int n, double *m, double *x)
*/
static
void
-gemv_kernel4xN_4_avx2(int ncol, double **row, double *x, double *y)
+gemv_4xN_kernel4_avx2(int ncol, double **row, double *x, double *y)
{
int c;
__m128d hr;
@@ -1137,17 +1137,17 @@ gemv_kernel4xN_4_avx2(int ncol, double **row, double *x, double *y)
static
void
-gemv_kernel4xN_4(int ncol, double **row, double *x, double *y)
+gemv_4xN_kernel4(int n, double **row, double *x, double *y)
{
int c;
double res[4];
- res[0] = 0.;
- res[1] = 0.;
- res[2] = 0.;
- res[3] = 0.;
+ res[0] = 0.0;
+ res[1] = 0.0;
+ res[2] = 0.0;
+ res[3] = 0.0;
- for (c = 0; c < ncol; c += 4) {
+ for (c = 0; c < n; c += 4) {
res[0] += row[0][c+0]*x[c+0] + row[0][c+1]*x[c+1] + row[0][c+2]*x[c+2] + row[0][c+3]*x[c+3];
res[1] += row[1][c+0]*x[c+0] + row[1][c+1]*x[c+1] + row[1][c+2]*x[c+2] + row[1][c+3]*x[c+3];
res[2] += row[2][c+0]*x[c+0] + row[2][c+1]*x[c+1] + row[2][c+2]*x[c+2] + row[2][c+3]*x[c+3];
@@ -1162,17 +1162,36 @@ gemv_kernel4xN_4(int ncol, double **row, double *x, double *y)
static
void
-gemv_kernel1xN_4(int ncol, double *row, double *x, double *y)
+gemv_1xN_kernel4_avx2(int n, double *row, double *x, double *y)
+{
+ int c;
+ __m128d r128;
+ __m256d r256;
+
+ r256 = _mm256_setzero_pd();
+ for (c = 0; c < n; c += 4) {
+ r256 += _mm256_loadu_pd(row+c) * _mm256_loadu_pd(x+c);
+ }
+
+ r128 = _mm_add_pd(_mm256_extractf128_pd(r256, 0), _mm256_extractf128_pd(r256, 1));
+ r128 = _mm_hadd_pd(r128, r128);
+
+ *y = r128[0];
+}
+
+static
+void
+gemv_1xN_kernel4(int n, double *row, double *x, double *y)
{
int c;
double res;
res = 0.;
- for (c = 0; c < ncol; c += 4) {
+ for (c = 0; c < n; c += 4) {
res += row[c+0]*x[c+0] + row[c+1]*x[c+1] + row[c+2]*x[c+2] + row[c+3]*x[c+3];
}
- y[0] = res;
+ *y = res;
}
error
@@ -1181,32 +1200,40 @@ blas·gemv(int nrow, int ncol, double a, double *m, double *x, double b, double
int c, r, nr, nc;
double *row[4], res[4];
- nr = nrow & ~3;
- nc = ncol & ~3;
- for (r = 0; r < nrow; r += 4) {
- row[0] = m + (r * (ncol+0));
- row[1] = m + (r * (ncol+1));
- row[2] = m + (r * (ncol+2));
- row[3] = m + (r * (ncol+3));
+ nr = EVEN_BY(nrow, 4);
+ nc = EVEN_BY(ncol, 4);
+
+ for (r = 0; r < nr; r += 4) {
+ /* assumes row major layout */
+ row[0] = m + ((r+0) * ncol);
+ row[1] = m + ((r+1) * ncol);
+ row[2] = m + ((r+2) * ncol);
+ row[3] = m + ((r+3) * ncol);
- gemv_kernel4xN_4_avx2(ncol, row, x + r, res);
+ gemv_4xN_kernel4_avx2(nc, row, x, res);
for (c = nc; c < ncol; c++) {
- res[0] += row[0][c];
- res[1] += row[1][c];
- res[2] += row[2][c];
- res[3] += row[3][c];
+ res[0] += row[0][c]*x[c];
+ res[1] += row[1][c]*x[c];
+ res[2] += row[2][c]*x[c];
+ res[3] += row[3][c]*x[c];
}
- y[r+0] = res[0] + b*y[r+0];
- y[r+1] = res[1] + b*y[r+1];
- y[r+2] = res[2] + b*y[r+2];
- y[r+3] = res[3] + b*y[r+3];
+ y[r+0] = a*res[0] + b*y[r+0];
+ y[r+1] = a*res[1] + b*y[r+1];
+ y[r+2] = a*res[2] + b*y[r+2];
+ y[r+3] = a*res[3] + b*y[r+3];
}
for (; r < nrow; r++) {
- gemv_kernel1xN_4(nrow, m + (r * ncol), x + r, res);
- y[r] = res[0] + b*y[r];
+ row[0] = m + (r * ncol);
+ gemv_1xN_kernel4_avx2(nc, row[0], x, res);
+
+ for (c = nc; c < ncol; c++) {
+ res[0] += row[0][c]*x[c];
+ }
+
+ y[r] = a*res[0] + b*y[r];
}
return 0;
@@ -1343,9 +1370,9 @@ blas·trsm(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2)
{
}
-#define NITER 3000
-#define NCOL 5000007
-#define NROW 57
+#define NITER 1000
+#define NCOL 1005
+#define NROW 1005
error
test·level3()
@@ -1416,40 +1443,46 @@ test·level2()
int i, j, n, it;
clock_t t;
- double *x, *y, *m;
+ double *x, *y, *z, *m;
double tprof[2];
rng·init(0);
- tprof[0] = 0, tprof[1] = 0;
+ tprof[0] = 0;
+ tprof[1] = 0;
+
x = malloc(sizeof(*x)*NCOL);
y = malloc(sizeof(*x)*NCOL);
- m = malloc(sizeof(*x)*NCOL*(NCOL+1)/2);
+ z = malloc(sizeof(*x)*NCOL);
+ m = malloc(sizeof(*x)*NROW*NCOL);
for (it = 0; it < NITER; it++) {
n = 0;
- for (i = 0; i < NCOL; i++) {
+ for (i = 0; i < NROW; i++) {
+ x[i] = rng·random();
y[i] = rng·random();
- for (j = i; j < NCOL; j++) {
- m[n++] = rng·random() + .1; // To ensure not singular
+ for (j = 0; j < NCOL; j++) {
+ m[n++] = rng·random() + .1;
}
}
- memcpy(x, y, NCOL * sizeof(*x));
+ memcpy(z, y, NCOL * sizeof(*y));
t = clock();
- blas·tpsv(0, NCOL, m, x);
+ blas·gemv(NROW, NCOL, 2, m, x, 0.0, y);
t = clock() - t;
+
tprof[0] += 1000.*t/CLOCKS_PER_SEC;
t = clock();
- cblas_dtpsv(CblasRowMajor, CblasUpper, CblasNoTrans, CblasNonUnit, NCOL, m, y, 1);
+ cblas_dgemv(CblasRowMajor, CblasNoTrans, NROW, NCOL, 2, m, NROW, x, 1, 0.0, z, 1);
t = clock() - t;
+
tprof[1] += 1000.*t/CLOCKS_PER_SEC;
for (i = 0; i < NCOL; i++) {
- if (math·abs(x[i] - y[i])/math·abs(x[i]) > 1e-5) {
- errorf("failure at index %d: %f != %f", i, x[i], y[i]);
+ if (math·abs(z[i] - y[i])/math·abs(x[i]) > 1e-5) {
+ errorf("failure at index %d: %f != %f", i, z[i], y[i]);
}
}
}
@@ -1529,7 +1562,7 @@ test·level1()
#define STEP 1
error
-main()
+test·argmax()
{
int i, n;
double *x, *y, *w, *z;
@@ -1575,9 +1608,12 @@ main()
// }
}
- printf("%f, %f\n", res[0], res[1]);
- printf("mean time/iteration (naive): %fms\n", tprof[0]/NITER);
- printf("mean time/iteration (oblas): %fms\n", tprof[1]/NITER);
+ return 0;
+}
+error
+main()
+{
+ test·level2();
return 0;
}