From 463ed852261da4d1dd1b859fa717a1d683306c9d Mon Sep 17 00:00:00 2001 From: Nicholas Noll Date: Thu, 14 May 2020 18:15:23 -0700 Subject: feat: begun work on final blas level 2 --- sys/libmath/blas.c | 27 ++-- sys/libmath/blas1.c | 214 +++++++++++++++------------- sys/libmath/gen1.py | 5 +- sys/libmath/gen2.py | 390 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 528 insertions(+), 108 deletions(-) create mode 100755 sys/libmath/gen2.py (limited to 'sys') diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c index f5392c8..43b0e70 100644 --- a/sys/libmath/blas.c +++ b/sys/libmath/blas.c @@ -6,37 +6,44 @@ #include -#define LEN 2000000 -#define NIT 2000 -#define INC 2 +#define NCOL 2000 +#define NROW 2000 +#define NIT 2000 +#define INC 2 error main() { - int i, nit; - double *x, *y, res[2]; + int i, j, nit; + double *x, *y, *m, res[2]; clock_t t; double tprof[2] = { 0 }; rng·init(0); - x = malloc(sizeof(*x)*LEN); - y = malloc(sizeof(*x)*LEN); + x = malloc(sizeof(*x)*NCOL); + y = malloc(sizeof(*x)*NROW); + m = malloc(sizeof(*x)*NROW*NCOL); #define DO_0 t = clock(); \ - res[0] += blas·sumd(LEN/INC, x, INC); \ + blas·gerd(NROW/INC, NCOL/INC, 1.2, x, INC, y, INC, m, NCOL); \ + res[0] += m[0]; \ t = clock() - t; \ tprof[0] += 1000.*t/CLOCKS_PER_SEC; \ #define DO_1 t = clock(); \ - res[1] += cblas_dsum(LEN/INC, x, INC); \ + cblas_dger(CblasRowMajor, NROW/INC, NCOL/INC, 1.2, x, INC, y, INC, m, NCOL); \ + res[1] += m[0]; \ t = clock() - t; \ tprof[1] += 1000.*t/CLOCKS_PER_SEC; for (nit = 0; nit < NIT; nit++) { - for (i = 0; i < LEN; i++) { + for (i = 0; i < NROW; i++) { x[i] = rng·random(); y[i] = rng·random(); + for (j = 0; j < NCOL; j++) { + m[j + NCOL*i] = rng·random(); + } } switch (nit % 2) { diff --git a/sys/libmath/blas1.c b/sys/libmath/blas1.c index 8cc70eb..1907179 100644 --- a/sys/libmath/blas1.c +++ b/sys/libmath/blas1.c @@ -173,7 +173,12 @@ blas·argminf(int len, float *x, int incx) i = len; ix = argminf_s_kernel8(&i, incx, x); } - min = x[ix]; + for (; i < len; i++) { + if (x[incx * i] < min) { + ix = i; + min = x[incx * ix]; + } + } for (; i < len; i++) { if (x[i] < min) { ix = i; @@ -188,16 +193,16 @@ static int argmaxf_kernel16(int *ip, float *x) { + int ix[16]; float max[16]; int i; - int ix[16]; int len; for (i = 0; i < 16; i++) { - max[i] = x[0]; + ix[i] = 0; } for (i = 0; i < 16; i++) { - ix[i] = 0; + max[i] = x[0]; } len = *ip & ~15; for (i = 0; i < len; i += 16) { @@ -280,16 +285,16 @@ static int argmaxf_s_kernel8(int *ip, int incx, float *x) { + int ix[8]; float max[8]; int i; - int ix[8]; int len; for (i = 0; i < 8; i++) { - max[i] = x[0]; + ix[i] = 0; } for (i = 0; i < 8; i++) { - ix[i] = 0; + max[i] = x[0]; } len = *ip & ~7; for (i = 0; i < len; i += 8) { @@ -350,7 +355,12 @@ blas·argmaxf(int len, float *x, int incx) i = len; ix = argmaxf_s_kernel8(&i, incx, x); } - max = x[ix]; + for (; i < len; i++) { + if (x[incx * i] > max) { + ix = i; + max = x[incx * ix]; + } + } for (; i < len; i++) { if (x[i] > max) { ix = i; @@ -392,7 +402,7 @@ copyf_kernel16(int *ip, float *x, float *y) static void -copyf_s_kernel8(int *ip, int incx, float *x, int incy, float *y) +copyf_s_kernel8(int *ip, int incx, int incy, float *x, float *y) { int i; int len; @@ -421,16 +431,16 @@ blas·copyf(int len, float *x, int incx, float *y, int incy) copyf_kernel16(&i, x, y); } else { i = len; - copyf_s_kernel8(&i, incx, x, incy, y); + copyf_s_kernel8(&i, incx, incy, x, y); } for (; i < len; i++) { - y[i] = x[i]; + y[incy * i] = x[incx * i]; } } static void -axpyf_kernel16(int *ip, float a, float *x, float *y) +axpyf_kernel16(int *ip, float *x, float a, float *y) { int i; int len; @@ -459,7 +469,7 @@ axpyf_kernel16(int *ip, float a, float *x, float *y) static void -axpyf_s_kernel8(int *ip, float a, float *x, int incy, float *y, int incx) +axpyf_s_kernel8(int *ip, float *x, float a, int incx, float *y, int incy) { int i; int len; @@ -485,19 +495,19 @@ blas·axpyf(int len, float a, float *x, int incx, float *y, int incy) if (incx == 1 && incy == 1) { i = len; - axpyf_kernel16(&i, a, x, y); + axpyf_kernel16(&i, x, a, y); } else { i = len; - axpyf_s_kernel8(&i, a, x, incy, y, incx); + axpyf_s_kernel8(&i, x, a, incx, y, incy); } for (; i < len; i++) { - y[i] = y[i] + a * x[i]; + y[incy * i] = y[incy * i] + a * x[incx * i]; } } static void -axpbyf_kernel16(int *ip, float *x, float a, float *y, float b) +axpbyf_kernel16(int *ip, float a, float *y, float b, float *x) { int i; int len; @@ -526,7 +536,7 @@ axpbyf_kernel16(int *ip, float *x, float a, float *y, float b) static void -axpbyf_s_kernel8(int *ip, float *x, float a, float *y, float b, int incx, int incy) +axpbyf_s_kernel8(int *ip, float a, int incx, float *y, float b, int incy, float *x) { int i; int len; @@ -552,22 +562,22 @@ blas·axpbyf(int len, float a, float *x, int incx, float b, float *y, int incy) if (incx == 1 && incy == 1) { i = len; - axpbyf_kernel16(&i, x, a, y, b); + axpbyf_kernel16(&i, a, y, b, x); } else { i = len; - axpbyf_s_kernel8(&i, x, a, y, b, incx, incy); + axpbyf_s_kernel8(&i, a, incx, y, b, incy, x); } for (; i < len; i++) { - y[i] = b * y[i] + a * x[i]; + y[incy * i] = b * y[incy * i] + a * x[incx * i]; } } static float -dotf_kernel16(int *ip, float *x, float *y) +dotf_kernel16(int *ip, float *y, float *x) { - int i; float sum[16]; + int i; int len; for (i = 0; i < 16; i++) { @@ -602,10 +612,10 @@ dotf_kernel16(int *ip, float *x, float *y) static float -dotf_s_kernel8(int *ip, float *x, int incy, float *y, int incx) +dotf_s_kernel8(int *ip, int incy, float *x, float *y, int incx) { - int i; float sum[8]; + int i; int len; for (i = 0; i < 8; i++) { @@ -638,13 +648,13 @@ blas·dotf(int len, float *x, int incx, float *y, int incy) if (incx == 1 && incy == 1) { i = len; - sum = dotf_kernel16(&i, x, y); + sum = dotf_kernel16(&i, y, x); } else { i = len; - sum = dotf_s_kernel8(&i, x, incy, y, incx); + sum = dotf_s_kernel8(&i, incy, x, y, incx); } for (; i < len; i++) { - sum += x[i] * y[i]; + sum += x[incx * i] * y[incy * i]; } return sum; @@ -654,8 +664,8 @@ static float sumf_kernel16(int *ip, float *x) { - int i; float sum[16]; + int i; int len; for (i = 0; i < 16; i++) { @@ -690,7 +700,7 @@ sumf_kernel16(int *ip, float *x) static float -sumf_s_kernel8(int *ip, float *x, int incx) +sumf_s_kernel8(int *ip, int incx, float *x) { float sum[8]; int i; @@ -729,10 +739,10 @@ blas·sumf(int len, float *x, int incx) sum = sumf_kernel16(&i, x); } else { i = len; - sum = sumf_s_kernel8(&i, x, incx); + sum = sumf_s_kernel8(&i, incx, x); } for (; i < len; i++) { - sum += x[i]; + sum += x[incx * i]; } return sum; @@ -742,8 +752,8 @@ static float normf_kernel16(int *ip, float *x) { - float nrm[16]; int i; + float nrm[16]; int len; for (i = 0; i < 16; i++) { @@ -780,8 +790,8 @@ static float normf_s_kernel8(int *ip, int incx, float *x) { - float nrm[8]; int i; + float nrm[8]; int len; for (i = 0; i < 8; i++) { @@ -820,7 +830,7 @@ blas·normf(int len, float *x, int incx) nrm = normf_s_kernel8(&i, incx, x); } for (; i < len; i++) { - nrm += x[i] * x[i]; + nrm += x[incx * i] * x[incx * i]; } return math·sqrtf(nrm); @@ -857,7 +867,7 @@ scalef_kernel16(int *ip, float *x, float a) static void -scalef_s_kernel8(int *ip, int incx, float *x, float a) +scalef_s_kernel8(int *ip, float *x, int incx, float a) { int i; int len; @@ -886,16 +896,16 @@ blas·scalef(int len, float *x, int incx, float a) scalef_kernel16(&i, x, a); } else { i = len; - scalef_s_kernel8(&i, incx, x, a); + scalef_s_kernel8(&i, x, incx, a); } for (; i < len; i++) { - x[i] = a * x[i]; + x[incx * i] = a * x[incx * i]; } } static void -rotf_kernel16(int *ip, float *y, float cos, float sin, float *x) +rotf_kernel16(int *ip, float sin, float *x, float cos, float *y) { int i; float tmp[16]; @@ -925,7 +935,7 @@ rotf_kernel16(int *ip, float *y, float cos, float sin, float *x) static void -rotf_s_kernel8(int *ip, float *y, float sin, float cos, float *x, int incx, int incy) +rotf_s_kernel8(int *ip, int incy, float cos, float sin, int incx, float *x, float *y) { int i; float tmp[8]; @@ -953,19 +963,19 @@ blas·rotf(int len, float *x, int incx, float *y, int incy, float cos, float sin if (incx == 1 && incy == 1) { i = len; - rotf_kernel16(&i, y, cos, sin, x); + rotf_kernel16(&i, sin, x, cos, y); } else { i = len; - rotf_s_kernel8(&i, y, sin, cos, x, incx, incy); + rotf_s_kernel8(&i, incy, cos, sin, incx, x, y); } for (; i < len; i++) { - tmp = x[i], x[i] = cos * x[i] + sin * y[i], y[i] = cos * y[i] - sin * tmp; + tmp = x[incx * i], x[incx * i] = cos * x[incx * i] + sin * y[incy * i], y[incy * i] = cos * y[incy * i] - sin * tmp; } } static void -rotgf_kernel16(int *ip, float H[5], float *y, float *x) +rotgf_kernel16(int *ip, float *x, float H[5], float *y) { float tmp[16]; int i; @@ -995,7 +1005,7 @@ rotgf_kernel16(int *ip, float H[5], float *y, float *x) static void -rotgf_s_kernel8(int *ip, int incy, int incx, float H[5], float *y, float *x) +rotgf_s_kernel8(int *ip, int incx, int incy, float *x, float H[5], float *y) { float tmp[8]; int i; @@ -1023,13 +1033,13 @@ blas·rotgf(int len, float *x, int incx, float *y, int incy, float H[5]) if (incx == 1 && incy == 1) { i = len; - rotgf_kernel16(&i, H, y, x); + rotgf_kernel16(&i, x, H, y); } else { i = len; - rotgf_s_kernel8(&i, incy, incx, H, y, x); + rotgf_s_kernel8(&i, incx, incy, x, H, y); } for (; i < len; i++) { - tmp = x[i], x[i] = H[1] * x[i] + H[2] * y[i], y[i] = H[3] * y[i] + H[4] * tmp; + tmp = x[incx * i], x[incx * i] = H[1] * x[incx * i] + H[2] * y[incy * i], y[incy * i] = H[3] * y[incy * i] + H[4] * tmp; } } @@ -1037,16 +1047,16 @@ static int argmind_kernel16(int *ip, double *x) { - int ix[16]; double min[16]; int i; + int ix[16]; int len; for (i = 0; i < 16; i++) { - ix[i] = 0; + min[i] = x[0]; } for (i = 0; i < 16; i++) { - min[i] = x[0]; + ix[i] = 0; } len = *ip & ~15; for (i = 0; i < len; i += 16) { @@ -1129,16 +1139,16 @@ static int argmind_s_kernel8(int *ip, double *x, int incx) { - int ix[8]; double min[8]; int i; + int ix[8]; int len; for (i = 0; i < 8; i++) { - ix[i] = 0; + min[i] = x[0]; } for (i = 0; i < 8; i++) { - min[i] = x[0]; + ix[i] = 0; } len = *ip & ~7; for (i = 0; i < len; i += 8) { @@ -1199,7 +1209,12 @@ blas·argmind(int len, double *x, int incx) i = len; ix = argmind_s_kernel8(&i, x, incx); } - min = x[ix]; + for (; i < len; i++) { + if (x[incx * i] < min) { + ix = i; + min = x[incx * ix]; + } + } for (; i < len; i++) { if (x[i] < min) { ix = i; @@ -1214,8 +1229,8 @@ static int argmaxd_kernel16(int *ip, double *x) { - int i; int ix[16]; + int i; double max[16]; int len; @@ -1306,8 +1321,8 @@ static int argmaxd_s_kernel8(int *ip, int incx, double *x) { - int i; int ix[8]; + int i; double max[8]; int len; @@ -1376,7 +1391,12 @@ blas·argmaxd(int len, double *x, int incx) i = len; ix = argmaxd_s_kernel8(&i, incx, x); } - max = x[ix]; + for (; i < len; i++) { + if (x[incx * i] > max) { + ix = i; + max = x[incx * ix]; + } + } for (; i < len; i++) { if (x[i] > max) { ix = i; @@ -1418,7 +1438,7 @@ copyd_kernel16(int *ip, double *y, double *x) static void -copyd_s_kernel8(int *ip, double *y, double *x, int incy, int incx) +copyd_s_kernel8(int *ip, double *y, int incx, int incy, double *x) { int i; int len; @@ -1447,16 +1467,16 @@ blas·copyd(int len, double *x, int incx, double *y, int incy) copyd_kernel16(&i, y, x); } else { i = len; - copyd_s_kernel8(&i, y, x, incy, incx); + copyd_s_kernel8(&i, y, incx, incy, x); } for (; i < len; i++) { - y[i] = x[i]; + y[incy * i] = x[incx * i]; } } static void -axpyd_kernel16(int *ip, double a, double *y, double *x) +axpyd_kernel16(int *ip, double *x, double *y, double a) { int i; int len; @@ -1485,7 +1505,7 @@ axpyd_kernel16(int *ip, double a, double *y, double *x) static void -axpyd_s_kernel8(int *ip, double a, int incx, double *y, double *x, int incy) +axpyd_s_kernel8(int *ip, double *x, int incy, double *y, double a, int incx) { int i; int len; @@ -1511,13 +1531,13 @@ blas·axpyd(int len, double a, double *x, int incx, double *y, int incy) if (incx == 1 && incy == 1) { i = len; - axpyd_kernel16(&i, a, y, x); + axpyd_kernel16(&i, x, y, a); } else { i = len; - axpyd_s_kernel8(&i, a, incx, y, x, incy); + axpyd_s_kernel8(&i, x, incy, y, a, incx); } for (; i < len; i++) { - y[i] = y[i] + a * x[i]; + y[incy * i] = y[incy * i] + a * x[incx * i]; } } @@ -1552,7 +1572,7 @@ axpbyd_kernel16(int *ip, double a, double b, double *x, double *y) static void -axpbyd_s_kernel8(int *ip, double a, double b, double *x, int incx, int incy, double *y) +axpbyd_s_kernel8(int *ip, double a, int incy, double b, double *x, int incx, double *y) { int i; int len; @@ -1581,16 +1601,16 @@ blas·axpbyd(int len, double a, double *x, int incx, double b, double *y, int in axpbyd_kernel16(&i, a, b, x, y); } else { i = len; - axpbyd_s_kernel8(&i, a, b, x, incx, incy, y); + axpbyd_s_kernel8(&i, a, incy, b, x, incx, y); } for (; i < len; i++) { - y[i] = b * y[i] + a * x[i]; + y[incy * i] = b * y[incy * i] + a * x[incx * i]; } } static double -dotd_kernel16(int *ip, double *x, double *y) +dotd_kernel16(int *ip, double *y, double *x) { double sum[16]; int i; @@ -1628,7 +1648,7 @@ dotd_kernel16(int *ip, double *x, double *y) static double -dotd_s_kernel8(int *ip, int incx, double *x, int incy, double *y) +dotd_s_kernel8(int *ip, double *y, int incx, int incy, double *x) { double sum[8]; int i; @@ -1664,13 +1684,13 @@ blas·dotd(int len, double *x, int incx, double *y, int incy) if (incx == 1 && incy == 1) { i = len; - sum = dotd_kernel16(&i, x, y); + sum = dotd_kernel16(&i, y, x); } else { i = len; - sum = dotd_s_kernel8(&i, incx, x, incy, y); + sum = dotd_s_kernel8(&i, y, incx, incy, x); } for (; i < len; i++) { - sum += x[i] * y[i]; + sum += x[incx * i] * y[incy * i]; } return sum; @@ -1716,7 +1736,7 @@ sumd_kernel16(int *ip, double *x) static double -sumd_s_kernel8(int *ip, double *x, int incx) +sumd_s_kernel8(int *ip, int incx, double *x) { int i; double sum[8]; @@ -1755,10 +1775,10 @@ blas·sumd(int len, double *x, int incx) sum = sumd_kernel16(&i, x); } else { i = len; - sum = sumd_s_kernel8(&i, x, incx); + sum = sumd_s_kernel8(&i, incx, x); } for (; i < len; i++) { - sum += x[i]; + sum += x[incx * i]; } return sum; @@ -1768,8 +1788,8 @@ static double normd_kernel16(int *ip, double *x) { - double nrm[16]; int i; + double nrm[16]; int len; for (i = 0; i < 16; i++) { @@ -1806,8 +1826,8 @@ static double normd_s_kernel8(int *ip, int incx, double *x) { - int i; double nrm[8]; + int i; int len; for (i = 0; i < 8; i++) { @@ -1846,7 +1866,7 @@ blas·normd(int len, double *x, int incx) nrm = normd_s_kernel8(&i, incx, x); } for (; i < len; i++) { - nrm += x[i] * x[i]; + nrm += x[incx * i] * x[incx * i]; } return math·sqrt(nrm); @@ -1883,7 +1903,7 @@ scaled_kernel16(int *ip, double a, double *x) static void -scaled_s_kernel8(int *ip, double a, int incx, double *x) +scaled_s_kernel8(int *ip, double *x, int incx, double a) { int i; int len; @@ -1912,19 +1932,19 @@ blas·scaled(int len, double *x, int incx, double a) scaled_kernel16(&i, a, x); } else { i = len; - scaled_s_kernel8(&i, a, incx, x); + scaled_s_kernel8(&i, x, incx, a); } for (; i < len; i++) { - x[i] = a * x[i]; + x[incx * i] = a * x[incx * i]; } } static void -rotd_kernel16(int *ip, double *x, double *y, double cos, double sin) +rotd_kernel16(int *ip, double sin, double cos, double *x, double *y) { - int i; double tmp[16]; + int i; int len; len = *ip & ~15; @@ -1951,10 +1971,10 @@ rotd_kernel16(int *ip, double *x, double *y, double cos, double sin) static void -rotd_s_kernel8(int *ip, double *y, int incx, double sin, int incy, double *x, double cos) +rotd_s_kernel8(int *ip, double sin, double *x, double *y, double cos, int incy, int incx) { - int i; double tmp[8]; + int i; int len; len = *ip & ~7; @@ -1979,22 +1999,22 @@ blas·rotd(int len, double *x, int incx, double *y, int incy, double cos, double if (incx == 1 && incy == 1) { i = len; - rotd_kernel16(&i, x, y, cos, sin); + rotd_kernel16(&i, sin, cos, x, y); } else { i = len; - rotd_s_kernel8(&i, y, incx, sin, incy, x, cos); + rotd_s_kernel8(&i, sin, x, y, cos, incy, incx); } for (; i < len; i++) { - tmp = x[i], x[i] = cos * x[i] + sin * y[i], y[i] = cos * y[i] - sin * tmp; + tmp = x[incx * i], x[incx * i] = cos * x[incx * i] + sin * y[incy * i], y[incy * i] = cos * y[incy * i] - sin * tmp; } } static void -rotgd_kernel16(int *ip, double *y, double H[5], double *x) +rotgd_kernel16(int *ip, double H[5], double *y, double *x) { - double tmp[16]; int i; + double tmp[16]; int len; len = *ip & ~15; @@ -2021,10 +2041,10 @@ rotgd_kernel16(int *ip, double *y, double H[5], double *x) static void -rotgd_s_kernel8(int *ip, int incx, double *y, int incy, double H[5], double *x) +rotgd_s_kernel8(int *ip, double H[5], int incx, double *y, double *x, int incy) { - double tmp[8]; int i; + double tmp[8]; int len; len = *ip & ~7; @@ -2049,13 +2069,13 @@ blas·rotgd(int len, double *x, int incx, double *y, int incy, double H[5]) if (incx == 1 && incy == 1) { i = len; - rotgd_kernel16(&i, y, H, x); + rotgd_kernel16(&i, H, y, x); } else { i = len; - rotgd_s_kernel8(&i, incx, y, incy, H, x); + rotgd_s_kernel8(&i, H, incx, y, x, incy); } for (; i < len; i++) { - tmp = x[i], x[i] = H[1] * x[i] + H[2] * y[i], y[i] = H[3] * y[i] + H[4] * tmp; + tmp = x[incx * i], x[incx * i] = H[1] * x[incx * i] + H[2] * y[incy * i], y[incy * i] = H[3] * y[incy * i] + H[4] * tmp; } } diff --git a/sys/libmath/gen1.py b/sys/libmath/gen1.py index b0f9ecc..936bc50 100755 --- a/sys/libmath/gen1.py +++ b/sys/libmath/gen1.py @@ -1,3 +1,5 @@ +#!/bin/python + from C import * NUNROLL = 16 @@ -12,7 +14,8 @@ def typeify(string, kind): def fini(func, loop, strided, calls, ret=[]): func.execute(*calls[:2]) - func = Strided(func, loop, NUNROLL//2, strided, *ret) + func, scall = Strided(func, loop, NUNROLL//2, strided, *ret) + calls[2] = scall[0] func.execute(*calls[2:]) func.emit() diff --git a/sys/libmath/gen2.py b/sys/libmath/gen2.py new file mode 100755 index 0000000..6ce2a12 --- /dev/null +++ b/sys/libmath/gen2.py @@ -0,0 +1,390 @@ +from C import * +import copy + +ROW = 4 +COL = 4 + +def pkg(name): + return f"blas·{name}" + +def typeify(string, kind): + if (kind == Float32): + return f"{string}f" + if (kind == Float64): + return f"{string}d" + +# ------------------------------------------------------------------------ +# Helpers (abandoning the automatic unroll from level 1) + +def toarray(len: int, *args): + return [Param(Array(arg.type, len), arg.name) for arg in args] + +def TryIndex(x, i): + if IsArrayType(x.var.type): + return Index(x, i) + return x + +def AddElts(root, *vars): + for var in vars: + root = Add(root ,var) + return root + +def Round(store, number, by): + return Set(store, And(number, Negate(I(by-1)))) + +def UnitIncs(root, *incs): + root = EQ(root, I(1)) + for inc in incs: + root = AndAnd(root, EQ(inc, I(1))) + return root + +def IsInc(p): + return p.name != "incx" and p.name != "incy" + +def Identity(p): + return True + +def FilterParams(params, func): + return [p for p in params if func(p)] + +def StrideAllIndexedTerms(stmts, var, itor, inc): + def is_hit(x): + if isinstance(x, Index): + if isinstance(x.i, BinaryOp) and x.x == var: + return x.i.l == itor + + return False + + terms = [] + for stmt in stmts: + Visit(stmt, lambda node: Filter(node, is_hit, terms)) + + for term in terms: + term.i = Mul(Paren(term.i), inc) + +def AsStrided(stmts, var, itor, inc): + def increment(x): + if isinstance(x, Index): + if isinstance(x.i, BinaryOp) and x.x == var: + return Index(x.x, Mul(Paren(x.i), inc)) + + return copy.copy(x) + + if isinstance(stmts, Block): + return Block(*[Make(stmt, lambda node: Transform(node, increment)) for stmt in stmts.stmts]) + elif isinstance(stmts, list): + return [Make(stmt, lambda node: Transform(node, increment(node))) for stmt in stmts] + else: + raise TypeError("unrecognized stmts type") + +class Iter(object): + def __init__(self, it, end, len, inc): + self.it = it + self.end = end + self.len = len + self.inc = inc + +def DoubleLoop(top, bot, Kernel, Preamble=[], Postamble=[]): + def Step(it, inc): + if inc == 1: + return Inc(it) + else: + return AddSet(it, I(inc)) + + return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc), + Block(*[ + *[func(i) for func in Preamble for i in range(top.inc)], + For(Set(bot.it, I(0)), LT(bot.it, bot.end), Step(bot.it, bot.inc), + Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)]) + ), + For(None, LT(bot.it, bot.len), Inc(bot.it), + Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, 1)]) + ), + *[func(i) for func in Postamble for i in range(bot.inc)] + ]) + ) + +def TriangularLoop(top, bot, Kernel, Preamble=[], Postamble=[], upper=True): + def Step(it, inc): + if inc == 1: + return Inc(it) + else: + return AddSet(it, I(inc)) + + def Finish(j): + if j == 0: + return For(None, LE(bot.it, top.it), Inc(bot.it), + Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)]) + ) + else: + return Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)], Inc(bot.it)) + + def Start(j, end): + if j == end: + return For(None, LT(bot.it, bot.end), Inc(bot.it), + Block(*[func for i in range(j+1) for func in Kernel(top.it, bot.it, i, 1)]) + ) + else: + return Block(*[func for i in range(j+1) for func in Kernel(top.it, bot.it, i, 1)], Inc(bot.it)) + + if upper: + return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc), + Block(*[ + *[func(i) for func in Preamble for i in range(top.inc)], + Set(bot.end, Add(Paren(EvenTo(Paren(Sub(top.end, top.it)), bot.inc)), top.it)), + Set(bot.it, top.it), + *[ Start(j, top.inc-1) for j in range(top.inc) if bot.inc > 1], + For(None, LT(bot.it, bot.len), Step(bot.it, bot.inc), + Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)]) + ), + *[func(i) for func in Postamble for i in range(bot.inc)] + ]) + ) + else: + return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc), + Block(*[ + *[func(i) for func in Preamble for i in range(top.inc)], + Set(bot.end, EvenTo(top.it, bot.inc)), + For(Set(bot.it, I(0)), LE(bot.it, bot.end), Step(bot.it, bot.inc), + Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)]) + ), + *[ Finish(j) for j in range(top.inc) if bot.inc > 1], + *[func(i) for func in Postamble for i in range(bot.inc)] + ]) + ) + + + +def ToKernel(name, loop): + vars = VarsUsed(StmtExpr(loop.init)) | VarsUsed(StmtExpr(loop.cond)) | \ + VarsUsed(StmtExpr(loop.step)) | VarsUsed(loop.body) + +# def ExpandAdd(i: int, c: Emitter, inc: int): +# offset = Add(c, I(0)) +# root = Mul(Index(Index(row, I(i)), offset), Index(x, offset)) +# for n in range(1, inc): +# offset = Add(c, I(n)) +# root = Add(root, Mul(Index(Index(row, I(i)), offset), Index(x, offset))) +# return root + +# ------------------------------------------------------------------------ +# Blas level 2 functions + +def trsv(kind): + name = typeify("trsv", kind) + F = Func(pkg(name), Void, + Params( + (UInt32, "flag"), (Int, "len"), (Ptr(kind), "m"), (Int, "incm"), (Ptr(kind), "x"), (Int, "incx") + ), + Vars( + (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res") + ) + ) + + r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res") + flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x") + incx = F.variables("incx") + + rows, cols = lambda inc_r: Iter(r, nr, _len, inc_r), lambda inc_c: Iter(c, nc, _len, inc_c) + + template = lambda inc_r, inc_c: TriangularLoop(rows(inc_r), cols(inc_c), + Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Add(c, I(j))), Mul(Index(res, I(i)), Index(x, Add(c, I(j))))) for j in range(inc)], + Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))), + lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))], + upper = True + ) + + loop = template(1, 1) + loop.emit() + +def syr(kind): + name = typeify("syr", kind) + F = Func(pkg(name), Void, + Params( + (UInt32, "flag"), (Int, "len"), (kind, "a"), + (Ptr(kind), "x"), (Int, "incx"), (Ptr(kind), "m"), (Int, "incm"), + ), + Vars( + (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res") + ) + ) + + r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res") + flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x") + incx = F.variables("incx") + + rows, cols = lambda inc_r: Iter(r, nr, _len, inc_r), lambda inc_c: Iter(c, nc, _len, inc_c) + + template = lambda inc_r, inc_c, upper: TriangularLoop(rows(inc_r), cols(inc_c), + Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Add(c, I(j))), Mul(Index(res, I(i)), Index(x, Add(c, I(j))))) for j in range(inc)], + Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))), + lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))], + upper = upper == "upper" + ) + + blocks = [] + for layout in ["lower", "upper"]: + floop = template(1, 1, layout) + sloop = template(1, 1, layout) + sloop.body = AsStrided(sloop.body, x, c, incx) + + fini = template(1, 1, layout) + fini.init = None + fini.body = AsStrided(fini.body, x, c, incx) + fini.cond = LT(r, _len) + + blocks.append( + Block( + If(UnitIncs(incx), Block(floop), Block(sloop)), + fini, + Return(), + ) + ) + F.execute(If(flag, blocks[0], blocks[1])) + F.emit() + +def ger(kind): + name = typeify("ger", kind) + F = Func(pkg(name), Void, + Params( + (Int, "nrow"), (Int, "ncol"), (kind, "a"), + (Ptr(kind), "x"), (Int, "incx"), (Ptr(kind), "y"), (Int, "incy"), (Ptr(kind), "m"), (Int, "incm"), + ), + Vars( + (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res") + ) + ) + + r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res") + nrow, ncol, a, m, incm, x, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "y") + incx, incy = F.variables("incx", "incy") + + rows, cols = lambda incr: Iter(r, nr, nrow, incr), lambda incc: Iter(c, nc, ncol, incc) + + template = lambda incr, incc: DoubleLoop(rows(incr), cols(incc), + Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Add(c, I(j))), Mul(Index(res, I(i)), Index(y, Add(c, I(j))))) for j in range(inc)], + Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))), + lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))], + ) + + # loop = template(1, 1) + # F.execute(loop) + # F.emit() + floop = template(ROW, COL) + sloop = template(ROW, COL) + sloop.body = AsStrided(AsStrided(sloop.body, x, c, incx), y, r, incy) + + fini = template(1, 2*COL) + fini.init = None + fini.body = AsStrided(AsStrided(fini.body, x, c, incx), y, r, incy) + fini.cond = LT(r, nrow) + + F.execute( + Set(nr, EvenTo(nrow, ROW)), + Set(nc, EvenTo(ncol, COL)), + If(UnitIncs(incx, incy), Block(floop), Block(sloop)) + ) + F.execute(fini) + F.emit() + +def gemv(kind): + name = typeify("gemv", kind) + params = Params( + (Int, "nrow"), (Int, "ncol"), (kind, "a"), (Ptr(kind), "m"), (Int, "incm"), + (Ptr(kind), "x"), (Int, "incx"), (kind, "b"), (Ptr(kind), "y"), (Int, "incy") + ) + stack = Vars((Int, "r"), (Int, "c"), (Ptr(kind), "row"), (kind, "res")) + F = Func(pkg(name), Void, params, stack) + + # --------------------- + # Kernel + + def innerloop(rinc, cit, cend, cinc): + return For(Set(cit, I(0)), LT(cit, cend), AddSet(cit, I(cinc)), + Block(*[AddSet(TryIndex(res, I(i)), + AddElts(*(Mul( + Index(TryIndex(row, I(i)), Add(cit, I(j))), + Index(x, Add(cit, I(j)))) for j in range(cinc) + ) + ) + ) for i in range(rinc)]) + ) + + def tryloop(rinc, cit, cend, cinc): + if cinc > 1: + loop = innerloop(rinc, cit, cend, 1) + loop.init = None + return loop + + def outerloop(rit, rlen, rinc, cit, cend, clen, cinc, row, res): + return For(Set(rit, I(0)), LT(rit, rlen), AddSet(r, I(rinc)), + Block( + *[Set(TryIndex(row, I(i)), Add(m, Mul(Paren(Add(rit, I(i))), incm))) for i in range(rinc)], + *[Set(TryIndex(res, I(i)), I(0)) for i in range(rinc)], + innerloop(rinc, cit, cend, cinc), + tryloop(rinc, cit, clen, cinc), + *[Set(Index(y, Add(rit, I(i))), Add(Mul(a, TryIndex(res, I(i))), Mul(b, Index(y, Add(rit, I(i)))))) for i in range(rinc)] + ) + ) + + kerns = [] + for func, sfx in [(IsInc, ""), (Identity, "_s")]: + kern = Func(f"{name}{sfx}_{ROW}x{COL}kernel", Void, FilterParams(params, func), stack[0:2] + toarray(ROW, *stack[2:]), static=True) + r, c, row, res = kern.variables("r", "c", "row", "res") + nrow, ncol, a, m, incm, x, b, y = kern.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y") + + ncolr = kern.declare(Var(Int, "ncolr")) + loop = outerloop(r, nrow, ROW, c, ncolr, ncol, COL, row, res) + + kern.execute(Round(ncolr, ncol, COL)) + kern.execute(loop) + if "_s" in sfx: + incx, incy = kern.variables("incx", "incy") + StrideAllIndexedTerms(kern.stmts, x, c, incx) + StrideAllIndexedTerms(kern.stmts, y, r, incy) + + kern.emit() + + kerns.append(kern) + + r, c, row, res = F.variables("r", "c", "row", "res") + nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y") + incx, incy = F.variables("incx", "incy") + F.execute(Round(r, nrow, ROW)) + F.execute( + If(UnitIncs(incx, incy), + Block(Call(kerns[0], [r, ncol, a, m, incm, x, b, y])), + Block(Call(kerns[1], [r, ncol, a, m, incm, x, incx, b, y, incy])), + ) + ) + + F.params = Params((UInt32, "flag")) + F.params + + remainder = outerloop(r, nrow, 1, c, ncol, ncol, COL, row, res) + remainder.init = None + F.execute(remainder) + StrideAllIndexedTerms(F, x, c, incx) + StrideAllIndexedTerms(F, y, r, incy) + + F.emit() + +# ------------------------------------------------------------------------ +# Code Generation + +if __name__ == "__main__": + emit("#include \n") + emit("#include \n") + emit("#include \n") + emitln() + emit("/*********************************************************/\n") + emit("/* THIS CODE IS GENERATED BY GEN2.PY! DON'T EDIT BY HAND */\n") + emit("/*********************************************************/\n") + emitln(2) + + for kind in [Float64]: #[Float32, Float64]: + trsv(kind) + # syr(kind) + # ger(kind) + # gemv(kind) + + flush() -- cgit v1.2.1