aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorNicholas Noll <nbnoll@eml.cc>2020-05-14 18:15:23 -0700
committerNicholas Noll <nbnoll@eml.cc>2020-05-14 18:15:23 -0700
commit463ed852261da4d1dd1b859fa717a1d683306c9d (patch)
treed832d08663dcb53f86073d4fbb1609712fe5513f
parentd982e7c2fdebf560ccce193cb98b85d4fac28a45 (diff)
feat: begun work on final blas level 2
-rw-r--r--include/libmath/blas.h79
-rw-r--r--sys/libmath/blas.c27
-rw-r--r--sys/libmath/blas1.c214
-rwxr-xr-xsys/libmath/gen1.py5
-rwxr-xr-xsys/libmath/gen2.py390
5 files changed, 607 insertions, 108 deletions
diff --git a/include/libmath/blas.h b/include/libmath/blas.h
new file mode 100644
index 0000000..b331a8d
--- /dev/null
+++ b/include/libmath/blas.h
@@ -0,0 +1,79 @@
+#pragma once
+
+// TODO: think of better names
+enum
+{
+ blas·LowerTri = 1u,
+ blas·Transpose = 2u,
+ blas·ConjTranspose = 4u,
+ blas·DiagOnes = 8u,
+ blas·LeftSide = 16u,
+};
+
+typedef uint32 blas·Flag;
+
+/*
+ * Floats
+ */
+
+// level 1
+void blas·rotf(int len, float *x, int incx, float *y, int incy, float cos, float sin);
+void blas·rotgf(float *a, float *b, float *cos, float *sin);
+error blas·rotmf(int len, float *x, int incx, float *y, int incy, float p[5]);
+void blas·scalef(int len, float a, float *x, int inc);
+void blas·copyf(int len, float *x, int incx, float *y, int incy);
+void blas·swapf(int len, float *x, int incx, float *y, int incy);
+void blas·axpyf(int len, float a, float *x, int incx, float *y, int incy);
+float blas·dotf(int len, float *x, int incx, float *y, int incy);
+float blas·normf(int len, float *x, int inc);
+float blas·sumf(int len, float *x, int inc);
+int blas·argmaxf(int len, float *x, int inc);
+int blas·argminf(int len, float *x, int inc);
+
+// level 2
+void blas·tpmvf(blas·Flag f, int n, float *m, float *x);
+error blas·gemvf(int nrow, int ncol, float a, float *m, int incm, float *x, int incx, float b, float *y, int incy) ;
+void blas·tpsvf(blas·Flag f, int n, float *m, float *x);
+void blas·gerf(int nrow, int ncol, float a, float *x, int incx, float *y, int incy, float *m, int incm);
+void blas·herf(int n, float a, float *x, float *m);
+void blas·syrf(int nrow, int ncol, float a, float *x, float *m);
+
+// level 3
+void blas·gemmf(int n1, int n2, int n3, float a, float *m1, float *m2, float b, float *m3);
+void blas·trmmf(blas·Flag f, int nrow, int ncol, float a, float *m1, float *m2);
+void blas·trsmf(blas·Flag f, int nrow, int ncol, float a, float *m1, float *m2);
+
+/*
+ * Doubles
+ */
+
+// level 1
+void blas·rotd(int len, double *x, int incx, double *y, int incy, double cos, double sin);
+void blas·rotgd(double *a, double *b, double *cos, double *sin);
+error blas·rotmd(int len, double *x, int incx, double *y, int incy, double p[5]);
+void blas·scaled(int len, double a, double *x, int inc);
+void blas·copyd(int len, double *x, int incx, double *y, int incy);
+void blas·swapd(int len, double *x, int incx, double *y, int incy);
+void blas·axpyd(int len, double a, double *x, int incx, double *y, int incy);
+double blas·dotd(int len, double *x, int incx, double *y, int incy);
+double blas·normd(int len, double *x, int inc);
+double blas·sumd(int len, double *x, int inc);
+int blas·argmaxd(int len, double *x, int inc);
+int blas·argmind(int len, double *x, int inc);
+
+// level 2
+void blas·tpmvd(blas·Flag f, int n, double *m, double *x);
+error blas·gemvd(int nrow, int ncol, double a, double *m, int incm, double *x, int incx, double b, double *y, int incy) ;
+void blas·tpsvd(blas·Flag f, int n, double *m, double *x);
+void blas·gerd(int nrow, int ncol, double a, double *x, int incx, double *y, int incy, double *m, int incm);
+void blas·herd(int n, double a, double *x, double *m);
+void blas·syrd(int nrow, int ncol, double a, double *x, double *m);
+
+// level 3
+void blas·gemmd(int n1, int n2, int n3, double a, double *m1, double *m2, double b, double *m3);
+void blas·trmmd(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2);
+void blas·trsmd(blas·Flag f, int nrow, int ncol, double a, double *m1, double *m2);
+
+/*
+ * TODO: Complex
+ */
diff --git a/sys/libmath/blas.c b/sys/libmath/blas.c
index f5392c8..43b0e70 100644
--- a/sys/libmath/blas.c
+++ b/sys/libmath/blas.c
@@ -6,37 +6,44 @@
#include <vendor/blas/cblas.h>
-#define LEN 2000000
-#define NIT 2000
-#define INC 2
+#define NCOL 2000
+#define NROW 2000
+#define NIT 2000
+#define INC 2
error
main()
{
- int i, nit;
- double *x, *y, res[2];
+ int i, j, nit;
+ double *x, *y, *m, res[2];
clock_t t;
double tprof[2] = { 0 };
rng·init(0);
- x = malloc(sizeof(*x)*LEN);
- y = malloc(sizeof(*x)*LEN);
+ x = malloc(sizeof(*x)*NCOL);
+ y = malloc(sizeof(*x)*NROW);
+ m = malloc(sizeof(*x)*NROW*NCOL);
#define DO_0 t = clock(); \
- res[0] += blas·sumd(LEN/INC, x, INC); \
+ blas·gerd(NROW/INC, NCOL/INC, 1.2, x, INC, y, INC, m, NCOL); \
+ res[0] += m[0]; \
t = clock() - t; \
tprof[0] += 1000.*t/CLOCKS_PER_SEC; \
#define DO_1 t = clock(); \
- res[1] += cblas_dsum(LEN/INC, x, INC); \
+ cblas_dger(CblasRowMajor, NROW/INC, NCOL/INC, 1.2, x, INC, y, INC, m, NCOL); \
+ res[1] += m[0]; \
t = clock() - t; \
tprof[1] += 1000.*t/CLOCKS_PER_SEC;
for (nit = 0; nit < NIT; nit++) {
- for (i = 0; i < LEN; i++) {
+ for (i = 0; i < NROW; i++) {
x[i] = rng·random();
y[i] = rng·random();
+ for (j = 0; j < NCOL; j++) {
+ m[j + NCOL*i] = rng·random();
+ }
}
switch (nit % 2) {
diff --git a/sys/libmath/blas1.c b/sys/libmath/blas1.c
index 8cc70eb..1907179 100644
--- a/sys/libmath/blas1.c
+++ b/sys/libmath/blas1.c
@@ -173,7 +173,12 @@ blas·argminf(int len, float *x, int incx)
i = len;
ix = argminf_s_kernel8(&i, incx, x);
}
- min = x[ix];
+ for (; i < len; i++) {
+ if (x[incx * i] < min) {
+ ix = i;
+ min = x[incx * ix];
+ }
+ }
for (; i < len; i++) {
if (x[i] < min) {
ix = i;
@@ -188,16 +193,16 @@ static
int
argmaxf_kernel16(int *ip, float *x)
{
+ int ix[16];
float max[16];
int i;
- int ix[16];
int len;
for (i = 0; i < 16; i++) {
- max[i] = x[0];
+ ix[i] = 0;
}
for (i = 0; i < 16; i++) {
- ix[i] = 0;
+ max[i] = x[0];
}
len = *ip & ~15;
for (i = 0; i < len; i += 16) {
@@ -280,16 +285,16 @@ static
int
argmaxf_s_kernel8(int *ip, int incx, float *x)
{
+ int ix[8];
float max[8];
int i;
- int ix[8];
int len;
for (i = 0; i < 8; i++) {
- max[i] = x[0];
+ ix[i] = 0;
}
for (i = 0; i < 8; i++) {
- ix[i] = 0;
+ max[i] = x[0];
}
len = *ip & ~7;
for (i = 0; i < len; i += 8) {
@@ -350,7 +355,12 @@ blas·argmaxf(int len, float *x, int incx)
i = len;
ix = argmaxf_s_kernel8(&i, incx, x);
}
- max = x[ix];
+ for (; i < len; i++) {
+ if (x[incx * i] > max) {
+ ix = i;
+ max = x[incx * ix];
+ }
+ }
for (; i < len; i++) {
if (x[i] > max) {
ix = i;
@@ -392,7 +402,7 @@ copyf_kernel16(int *ip, float *x, float *y)
static
void
-copyf_s_kernel8(int *ip, int incx, float *x, int incy, float *y)
+copyf_s_kernel8(int *ip, int incx, int incy, float *x, float *y)
{
int i;
int len;
@@ -421,16 +431,16 @@ blas·copyf(int len, float *x, int incx, float *y, int incy)
copyf_kernel16(&i, x, y);
} else {
i = len;
- copyf_s_kernel8(&i, incx, x, incy, y);
+ copyf_s_kernel8(&i, incx, incy, x, y);
}
for (; i < len; i++) {
- y[i] = x[i];
+ y[incy * i] = x[incx * i];
}
}
static
void
-axpyf_kernel16(int *ip, float a, float *x, float *y)
+axpyf_kernel16(int *ip, float *x, float a, float *y)
{
int i;
int len;
@@ -459,7 +469,7 @@ axpyf_kernel16(int *ip, float a, float *x, float *y)
static
void
-axpyf_s_kernel8(int *ip, float a, float *x, int incy, float *y, int incx)
+axpyf_s_kernel8(int *ip, float *x, float a, int incx, float *y, int incy)
{
int i;
int len;
@@ -485,19 +495,19 @@ blas·axpyf(int len, float a, float *x, int incx, float *y, int incy)
if (incx == 1 && incy == 1) {
i = len;
- axpyf_kernel16(&i, a, x, y);
+ axpyf_kernel16(&i, x, a, y);
} else {
i = len;
- axpyf_s_kernel8(&i, a, x, incy, y, incx);
+ axpyf_s_kernel8(&i, x, a, incx, y, incy);
}
for (; i < len; i++) {
- y[i] = y[i] + a * x[i];
+ y[incy * i] = y[incy * i] + a * x[incx * i];
}
}
static
void
-axpbyf_kernel16(int *ip, float *x, float a, float *y, float b)
+axpbyf_kernel16(int *ip, float a, float *y, float b, float *x)
{
int i;
int len;
@@ -526,7 +536,7 @@ axpbyf_kernel16(int *ip, float *x, float a, float *y, float b)
static
void
-axpbyf_s_kernel8(int *ip, float *x, float a, float *y, float b, int incx, int incy)
+axpbyf_s_kernel8(int *ip, float a, int incx, float *y, float b, int incy, float *x)
{
int i;
int len;
@@ -552,22 +562,22 @@ blas·axpbyf(int len, float a, float *x, int incx, float b, float *y, int incy)
if (incx == 1 && incy == 1) {
i = len;
- axpbyf_kernel16(&i, x, a, y, b);
+ axpbyf_kernel16(&i, a, y, b, x);
} else {
i = len;
- axpbyf_s_kernel8(&i, x, a, y, b, incx, incy);
+ axpbyf_s_kernel8(&i, a, incx, y, b, incy, x);
}
for (; i < len; i++) {
- y[i] = b * y[i] + a * x[i];
+ y[incy * i] = b * y[incy * i] + a * x[incx * i];
}
}
static
float
-dotf_kernel16(int *ip, float *x, float *y)
+dotf_kernel16(int *ip, float *y, float *x)
{
- int i;
float sum[16];
+ int i;
int len;
for (i = 0; i < 16; i++) {
@@ -602,10 +612,10 @@ dotf_kernel16(int *ip, float *x, float *y)
static
float
-dotf_s_kernel8(int *ip, float *x, int incy, float *y, int incx)
+dotf_s_kernel8(int *ip, int incy, float *x, float *y, int incx)
{
- int i;
float sum[8];
+ int i;
int len;
for (i = 0; i < 8; i++) {
@@ -638,13 +648,13 @@ blas·dotf(int len, float *x, int incx, float *y, int incy)
if (incx == 1 && incy == 1) {
i = len;
- sum = dotf_kernel16(&i, x, y);
+ sum = dotf_kernel16(&i, y, x);
} else {
i = len;
- sum = dotf_s_kernel8(&i, x, incy, y, incx);
+ sum = dotf_s_kernel8(&i, incy, x, y, incx);
}
for (; i < len; i++) {
- sum += x[i] * y[i];
+ sum += x[incx * i] * y[incy * i];
}
return sum;
@@ -654,8 +664,8 @@ static
float
sumf_kernel16(int *ip, float *x)
{
- int i;
float sum[16];
+ int i;
int len;
for (i = 0; i < 16; i++) {
@@ -690,7 +700,7 @@ sumf_kernel16(int *ip, float *x)
static
float
-sumf_s_kernel8(int *ip, float *x, int incx)
+sumf_s_kernel8(int *ip, int incx, float *x)
{
float sum[8];
int i;
@@ -729,10 +739,10 @@ blas·sumf(int len, float *x, int incx)
sum = sumf_kernel16(&i, x);
} else {
i = len;
- sum = sumf_s_kernel8(&i, x, incx);
+ sum = sumf_s_kernel8(&i, incx, x);
}
for (; i < len; i++) {
- sum += x[i];
+ sum += x[incx * i];
}
return sum;
@@ -742,8 +752,8 @@ static
float
normf_kernel16(int *ip, float *x)
{
- float nrm[16];
int i;
+ float nrm[16];
int len;
for (i = 0; i < 16; i++) {
@@ -780,8 +790,8 @@ static
float
normf_s_kernel8(int *ip, int incx, float *x)
{
- float nrm[8];
int i;
+ float nrm[8];
int len;
for (i = 0; i < 8; i++) {
@@ -820,7 +830,7 @@ blas·normf(int len, float *x, int incx)
nrm = normf_s_kernel8(&i, incx, x);
}
for (; i < len; i++) {
- nrm += x[i] * x[i];
+ nrm += x[incx * i] * x[incx * i];
}
return math·sqrtf(nrm);
@@ -857,7 +867,7 @@ scalef_kernel16(int *ip, float *x, float a)
static
void
-scalef_s_kernel8(int *ip, int incx, float *x, float a)
+scalef_s_kernel8(int *ip, float *x, int incx, float a)
{
int i;
int len;
@@ -886,16 +896,16 @@ blas·scalef(int len, float *x, int incx, float a)
scalef_kernel16(&i, x, a);
} else {
i = len;
- scalef_s_kernel8(&i, incx, x, a);
+ scalef_s_kernel8(&i, x, incx, a);
}
for (; i < len; i++) {
- x[i] = a * x[i];
+ x[incx * i] = a * x[incx * i];
}
}
static
void
-rotf_kernel16(int *ip, float *y, float cos, float sin, float *x)
+rotf_kernel16(int *ip, float sin, float *x, float cos, float *y)
{
int i;
float tmp[16];
@@ -925,7 +935,7 @@ rotf_kernel16(int *ip, float *y, float cos, float sin, float *x)
static
void
-rotf_s_kernel8(int *ip, float *y, float sin, float cos, float *x, int incx, int incy)
+rotf_s_kernel8(int *ip, int incy, float cos, float sin, int incx, float *x, float *y)
{
int i;
float tmp[8];
@@ -953,19 +963,19 @@ blas·rotf(int len, float *x, int incx, float *y, int incy, float cos, float sin
if (incx == 1 && incy == 1) {
i = len;
- rotf_kernel16(&i, y, cos, sin, x);
+ rotf_kernel16(&i, sin, x, cos, y);
} else {
i = len;
- rotf_s_kernel8(&i, y, sin, cos, x, incx, incy);
+ rotf_s_kernel8(&i, incy, cos, sin, incx, x, y);
}
for (; i < len; i++) {
- tmp = x[i], x[i] = cos * x[i] + sin * y[i], y[i] = cos * y[i] - sin * tmp;
+ tmp = x[incx * i], x[incx * i] = cos * x[incx * i] + sin * y[incy * i], y[incy * i] = cos * y[incy * i] - sin * tmp;
}
}
static
void
-rotgf_kernel16(int *ip, float H[5], float *y, float *x)
+rotgf_kernel16(int *ip, float *x, float H[5], float *y)
{
float tmp[16];
int i;
@@ -995,7 +1005,7 @@ rotgf_kernel16(int *ip, float H[5], float *y, float *x)
static
void
-rotgf_s_kernel8(int *ip, int incy, int incx, float H[5], float *y, float *x)
+rotgf_s_kernel8(int *ip, int incx, int incy, float *x, float H[5], float *y)
{
float tmp[8];
int i;
@@ -1023,13 +1033,13 @@ blas·rotgf(int len, float *x, int incx, float *y, int incy, float H[5])
if (incx == 1 && incy == 1) {
i = len;
- rotgf_kernel16(&i, H, y, x);
+ rotgf_kernel16(&i, x, H, y);
} else {
i = len;
- rotgf_s_kernel8(&i, incy, incx, H, y, x);
+ rotgf_s_kernel8(&i, incx, incy, x, H, y);
}
for (; i < len; i++) {
- tmp = x[i], x[i] = H[1] * x[i] + H[2] * y[i], y[i] = H[3] * y[i] + H[4] * tmp;
+ tmp = x[incx * i], x[incx * i] = H[1] * x[incx * i] + H[2] * y[incy * i], y[incy * i] = H[3] * y[incy * i] + H[4] * tmp;
}
}
@@ -1037,16 +1047,16 @@ static
int
argmind_kernel16(int *ip, double *x)
{
- int ix[16];
double min[16];
int i;
+ int ix[16];
int len;
for (i = 0; i < 16; i++) {
- ix[i] = 0;
+ min[i] = x[0];
}
for (i = 0; i < 16; i++) {
- min[i] = x[0];
+ ix[i] = 0;
}
len = *ip & ~15;
for (i = 0; i < len; i += 16) {
@@ -1129,16 +1139,16 @@ static
int
argmind_s_kernel8(int *ip, double *x, int incx)
{
- int ix[8];
double min[8];
int i;
+ int ix[8];
int len;
for (i = 0; i < 8; i++) {
- ix[i] = 0;
+ min[i] = x[0];
}
for (i = 0; i < 8; i++) {
- min[i] = x[0];
+ ix[i] = 0;
}
len = *ip & ~7;
for (i = 0; i < len; i += 8) {
@@ -1199,7 +1209,12 @@ blas·argmind(int len, double *x, int incx)
i = len;
ix = argmind_s_kernel8(&i, x, incx);
}
- min = x[ix];
+ for (; i < len; i++) {
+ if (x[incx * i] < min) {
+ ix = i;
+ min = x[incx * ix];
+ }
+ }
for (; i < len; i++) {
if (x[i] < min) {
ix = i;
@@ -1214,8 +1229,8 @@ static
int
argmaxd_kernel16(int *ip, double *x)
{
- int i;
int ix[16];
+ int i;
double max[16];
int len;
@@ -1306,8 +1321,8 @@ static
int
argmaxd_s_kernel8(int *ip, int incx, double *x)
{
- int i;
int ix[8];
+ int i;
double max[8];
int len;
@@ -1376,7 +1391,12 @@ blas·argmaxd(int len, double *x, int incx)
i = len;
ix = argmaxd_s_kernel8(&i, incx, x);
}
- max = x[ix];
+ for (; i < len; i++) {
+ if (x[incx * i] > max) {
+ ix = i;
+ max = x[incx * ix];
+ }
+ }
for (; i < len; i++) {
if (x[i] > max) {
ix = i;
@@ -1418,7 +1438,7 @@ copyd_kernel16(int *ip, double *y, double *x)
static
void
-copyd_s_kernel8(int *ip, double *y, double *x, int incy, int incx)
+copyd_s_kernel8(int *ip, double *y, int incx, int incy, double *x)
{
int i;
int len;
@@ -1447,16 +1467,16 @@ blas·copyd(int len, double *x, int incx, double *y, int incy)
copyd_kernel16(&i, y, x);
} else {
i = len;
- copyd_s_kernel8(&i, y, x, incy, incx);
+ copyd_s_kernel8(&i, y, incx, incy, x);
}
for (; i < len; i++) {
- y[i] = x[i];
+ y[incy * i] = x[incx * i];
}
}
static
void
-axpyd_kernel16(int *ip, double a, double *y, double *x)
+axpyd_kernel16(int *ip, double *x, double *y, double a)
{
int i;
int len;
@@ -1485,7 +1505,7 @@ axpyd_kernel16(int *ip, double a, double *y, double *x)
static
void
-axpyd_s_kernel8(int *ip, double a, int incx, double *y, double *x, int incy)
+axpyd_s_kernel8(int *ip, double *x, int incy, double *y, double a, int incx)
{
int i;
int len;
@@ -1511,13 +1531,13 @@ blas·axpyd(int len, double a, double *x, int incx, double *y, int incy)
if (incx == 1 && incy == 1) {
i = len;
- axpyd_kernel16(&i, a, y, x);
+ axpyd_kernel16(&i, x, y, a);
} else {
i = len;
- axpyd_s_kernel8(&i, a, incx, y, x, incy);
+ axpyd_s_kernel8(&i, x, incy, y, a, incx);
}
for (; i < len; i++) {
- y[i] = y[i] + a * x[i];
+ y[incy * i] = y[incy * i] + a * x[incx * i];
}
}
@@ -1552,7 +1572,7 @@ axpbyd_kernel16(int *ip, double a, double b, double *x, double *y)
static
void
-axpbyd_s_kernel8(int *ip, double a, double b, double *x, int incx, int incy, double *y)
+axpbyd_s_kernel8(int *ip, double a, int incy, double b, double *x, int incx, double *y)
{
int i;
int len;
@@ -1581,16 +1601,16 @@ blas·axpbyd(int len, double a, double *x, int incx, double b, double *y, int in
axpbyd_kernel16(&i, a, b, x, y);
} else {
i = len;
- axpbyd_s_kernel8(&i, a, b, x, incx, incy, y);
+ axpbyd_s_kernel8(&i, a, incy, b, x, incx, y);
}
for (; i < len; i++) {
- y[i] = b * y[i] + a * x[i];
+ y[incy * i] = b * y[incy * i] + a * x[incx * i];
}
}
static
double
-dotd_kernel16(int *ip, double *x, double *y)
+dotd_kernel16(int *ip, double *y, double *x)
{
double sum[16];
int i;
@@ -1628,7 +1648,7 @@ dotd_kernel16(int *ip, double *x, double *y)
static
double
-dotd_s_kernel8(int *ip, int incx, double *x, int incy, double *y)
+dotd_s_kernel8(int *ip, double *y, int incx, int incy, double *x)
{
double sum[8];
int i;
@@ -1664,13 +1684,13 @@ blas·dotd(int len, double *x, int incx, double *y, int incy)
if (incx == 1 && incy == 1) {
i = len;
- sum = dotd_kernel16(&i, x, y);
+ sum = dotd_kernel16(&i, y, x);
} else {
i = len;
- sum = dotd_s_kernel8(&i, incx, x, incy, y);
+ sum = dotd_s_kernel8(&i, y, incx, incy, x);
}
for (; i < len; i++) {
- sum += x[i] * y[i];
+ sum += x[incx * i] * y[incy * i];
}
return sum;
@@ -1716,7 +1736,7 @@ sumd_kernel16(int *ip, double *x)
static
double
-sumd_s_kernel8(int *ip, double *x, int incx)
+sumd_s_kernel8(int *ip, int incx, double *x)
{
int i;
double sum[8];
@@ -1755,10 +1775,10 @@ blas·sumd(int len, double *x, int incx)
sum = sumd_kernel16(&i, x);
} else {
i = len;
- sum = sumd_s_kernel8(&i, x, incx);
+ sum = sumd_s_kernel8(&i, incx, x);
}
for (; i < len; i++) {
- sum += x[i];
+ sum += x[incx * i];
}
return sum;
@@ -1768,8 +1788,8 @@ static
double
normd_kernel16(int *ip, double *x)
{
- double nrm[16];
int i;
+ double nrm[16];
int len;
for (i = 0; i < 16; i++) {
@@ -1806,8 +1826,8 @@ static
double
normd_s_kernel8(int *ip, int incx, double *x)
{
- int i;
double nrm[8];
+ int i;
int len;
for (i = 0; i < 8; i++) {
@@ -1846,7 +1866,7 @@ blas·normd(int len, double *x, int incx)
nrm = normd_s_kernel8(&i, incx, x);
}
for (; i < len; i++) {
- nrm += x[i] * x[i];
+ nrm += x[incx * i] * x[incx * i];
}
return math·sqrt(nrm);
@@ -1883,7 +1903,7 @@ scaled_kernel16(int *ip, double a, double *x)
static
void
-scaled_s_kernel8(int *ip, double a, int incx, double *x)
+scaled_s_kernel8(int *ip, double *x, int incx, double a)
{
int i;
int len;
@@ -1912,19 +1932,19 @@ blas·scaled(int len, double *x, int incx, double a)
scaled_kernel16(&i, a, x);
} else {
i = len;
- scaled_s_kernel8(&i, a, incx, x);
+ scaled_s_kernel8(&i, x, incx, a);
}
for (; i < len; i++) {
- x[i] = a * x[i];
+ x[incx * i] = a * x[incx * i];
}
}
static
void
-rotd_kernel16(int *ip, double *x, double *y, double cos, double sin)
+rotd_kernel16(int *ip, double sin, double cos, double *x, double *y)
{
- int i;
double tmp[16];
+ int i;
int len;
len = *ip & ~15;
@@ -1951,10 +1971,10 @@ rotd_kernel16(int *ip, double *x, double *y, double cos, double sin)
static
void
-rotd_s_kernel8(int *ip, double *y, int incx, double sin, int incy, double *x, double cos)
+rotd_s_kernel8(int *ip, double sin, double *x, double *y, double cos, int incy, int incx)
{
- int i;
double tmp[8];
+ int i;
int len;
len = *ip & ~7;
@@ -1979,22 +1999,22 @@ blas·rotd(int len, double *x, int incx, double *y, int incy, double cos, double
if (incx == 1 && incy == 1) {
i = len;
- rotd_kernel16(&i, x, y, cos, sin);
+ rotd_kernel16(&i, sin, cos, x, y);
} else {
i = len;
- rotd_s_kernel8(&i, y, incx, sin, incy, x, cos);
+ rotd_s_kernel8(&i, sin, x, y, cos, incy, incx);
}
for (; i < len; i++) {
- tmp = x[i], x[i] = cos * x[i] + sin * y[i], y[i] = cos * y[i] - sin * tmp;
+ tmp = x[incx * i], x[incx * i] = cos * x[incx * i] + sin * y[incy * i], y[incy * i] = cos * y[incy * i] - sin * tmp;
}
}
static
void
-rotgd_kernel16(int *ip, double *y, double H[5], double *x)
+rotgd_kernel16(int *ip, double H[5], double *y, double *x)
{
- double tmp[16];
int i;
+ double tmp[16];
int len;
len = *ip & ~15;
@@ -2021,10 +2041,10 @@ rotgd_kernel16(int *ip, double *y, double H[5], double *x)
static
void
-rotgd_s_kernel8(int *ip, int incx, double *y, int incy, double H[5], double *x)
+rotgd_s_kernel8(int *ip, double H[5], int incx, double *y, double *x, int incy)
{
- double tmp[8];
int i;
+ double tmp[8];
int len;
len = *ip & ~7;
@@ -2049,13 +2069,13 @@ blas·rotgd(int len, double *x, int incx, double *y, int incy, double H[5])
if (incx == 1 && incy == 1) {
i = len;
- rotgd_kernel16(&i, y, H, x);
+ rotgd_kernel16(&i, H, y, x);
} else {
i = len;
- rotgd_s_kernel8(&i, incx, y, incy, H, x);
+ rotgd_s_kernel8(&i, H, incx, y, x, incy);
}
for (; i < len; i++) {
- tmp = x[i], x[i] = H[1] * x[i] + H[2] * y[i], y[i] = H[3] * y[i] + H[4] * tmp;
+ tmp = x[incx * i], x[incx * i] = H[1] * x[incx * i] + H[2] * y[incy * i], y[incy * i] = H[3] * y[incy * i] + H[4] * tmp;
}
}
diff --git a/sys/libmath/gen1.py b/sys/libmath/gen1.py
index b0f9ecc..936bc50 100755
--- a/sys/libmath/gen1.py
+++ b/sys/libmath/gen1.py
@@ -1,3 +1,5 @@
+#!/bin/python
+
from C import *
NUNROLL = 16
@@ -12,7 +14,8 @@ def typeify(string, kind):
def fini(func, loop, strided, calls, ret=[]):
func.execute(*calls[:2])
- func = Strided(func, loop, NUNROLL//2, strided, *ret)
+ func, scall = Strided(func, loop, NUNROLL//2, strided, *ret)
+ calls[2] = scall[0]
func.execute(*calls[2:])
func.emit()
diff --git a/sys/libmath/gen2.py b/sys/libmath/gen2.py
new file mode 100755
index 0000000..6ce2a12
--- /dev/null
+++ b/sys/libmath/gen2.py
@@ -0,0 +1,390 @@
+from C import *
+import copy
+
+ROW = 4
+COL = 4
+
+def pkg(name):
+ return f"blas·{name}"
+
+def typeify(string, kind):
+ if (kind == Float32):
+ return f"{string}f"
+ if (kind == Float64):
+ return f"{string}d"
+
+# ------------------------------------------------------------------------
+# Helpers (abandoning the automatic unroll from level 1)
+
+def toarray(len: int, *args):
+ return [Param(Array(arg.type, len), arg.name) for arg in args]
+
+def TryIndex(x, i):
+ if IsArrayType(x.var.type):
+ return Index(x, i)
+ return x
+
+def AddElts(root, *vars):
+ for var in vars:
+ root = Add(root ,var)
+ return root
+
+def Round(store, number, by):
+ return Set(store, And(number, Negate(I(by-1))))
+
+def UnitIncs(root, *incs):
+ root = EQ(root, I(1))
+ for inc in incs:
+ root = AndAnd(root, EQ(inc, I(1)))
+ return root
+
+def IsInc(p):
+ return p.name != "incx" and p.name != "incy"
+
+def Identity(p):
+ return True
+
+def FilterParams(params, func):
+ return [p for p in params if func(p)]
+
+def StrideAllIndexedTerms(stmts, var, itor, inc):
+ def is_hit(x):
+ if isinstance(x, Index):
+ if isinstance(x.i, BinaryOp) and x.x == var:
+ return x.i.l == itor
+
+ return False
+
+ terms = []
+ for stmt in stmts:
+ Visit(stmt, lambda node: Filter(node, is_hit, terms))
+
+ for term in terms:
+ term.i = Mul(Paren(term.i), inc)
+
+def AsStrided(stmts, var, itor, inc):
+ def increment(x):
+ if isinstance(x, Index):
+ if isinstance(x.i, BinaryOp) and x.x == var:
+ return Index(x.x, Mul(Paren(x.i), inc))
+
+ return copy.copy(x)
+
+ if isinstance(stmts, Block):
+ return Block(*[Make(stmt, lambda node: Transform(node, increment)) for stmt in stmts.stmts])
+ elif isinstance(stmts, list):
+ return [Make(stmt, lambda node: Transform(node, increment(node))) for stmt in stmts]
+ else:
+ raise TypeError("unrecognized stmts type")
+
+class Iter(object):
+ def __init__(self, it, end, len, inc):
+ self.it = it
+ self.end = end
+ self.len = len
+ self.inc = inc
+
+def DoubleLoop(top, bot, Kernel, Preamble=[], Postamble=[]):
+ def Step(it, inc):
+ if inc == 1:
+ return Inc(it)
+ else:
+ return AddSet(it, I(inc))
+
+ return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc),
+ Block(*[
+ *[func(i) for func in Preamble for i in range(top.inc)],
+ For(Set(bot.it, I(0)), LT(bot.it, bot.end), Step(bot.it, bot.inc),
+ Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)])
+ ),
+ For(None, LT(bot.it, bot.len), Inc(bot.it),
+ Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, 1)])
+ ),
+ *[func(i) for func in Postamble for i in range(bot.inc)]
+ ])
+ )
+
+def TriangularLoop(top, bot, Kernel, Preamble=[], Postamble=[], upper=True):
+ def Step(it, inc):
+ if inc == 1:
+ return Inc(it)
+ else:
+ return AddSet(it, I(inc))
+
+ def Finish(j):
+ if j == 0:
+ return For(None, LE(bot.it, top.it), Inc(bot.it),
+ Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)])
+ )
+ else:
+ return Block(*[func for i in range(j, top.inc) for func in Kernel(top.it, bot.it, i, 1)], Inc(bot.it))
+
+ def Start(j, end):
+ if j == end:
+ return For(None, LT(bot.it, bot.end), Inc(bot.it),
+ Block(*[func for i in range(j+1) for func in Kernel(top.it, bot.it, i, 1)])
+ )
+ else:
+ return Block(*[func for i in range(j+1) for func in Kernel(top.it, bot.it, i, 1)], Inc(bot.it))
+
+ if upper:
+ return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc),
+ Block(*[
+ *[func(i) for func in Preamble for i in range(top.inc)],
+ Set(bot.end, Add(Paren(EvenTo(Paren(Sub(top.end, top.it)), bot.inc)), top.it)),
+ Set(bot.it, top.it),
+ *[ Start(j, top.inc-1) for j in range(top.inc) if bot.inc > 1],
+ For(None, LT(bot.it, bot.len), Step(bot.it, bot.inc),
+ Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)])
+ ),
+ *[func(i) for func in Postamble for i in range(bot.inc)]
+ ])
+ )
+ else:
+ return For(Set(top.it, I(0)), LT(top.it, top.end), Step(top.it, top.inc),
+ Block(*[
+ *[func(i) for func in Preamble for i in range(top.inc)],
+ Set(bot.end, EvenTo(top.it, bot.inc)),
+ For(Set(bot.it, I(0)), LE(bot.it, bot.end), Step(bot.it, bot.inc),
+ Block(*[func for i in range(top.inc) for func in Kernel(top.it, bot.it, i, bot.inc)])
+ ),
+ *[ Finish(j) for j in range(top.inc) if bot.inc > 1],
+ *[func(i) for func in Postamble for i in range(bot.inc)]
+ ])
+ )
+
+
+
+def ToKernel(name, loop):
+ vars = VarsUsed(StmtExpr(loop.init)) | VarsUsed(StmtExpr(loop.cond)) | \
+ VarsUsed(StmtExpr(loop.step)) | VarsUsed(loop.body)
+
+# def ExpandAdd(i: int, c: Emitter, inc: int):
+# offset = Add(c, I(0))
+# root = Mul(Index(Index(row, I(i)), offset), Index(x, offset))
+# for n in range(1, inc):
+# offset = Add(c, I(n))
+# root = Add(root, Mul(Index(Index(row, I(i)), offset), Index(x, offset)))
+# return root
+
+# ------------------------------------------------------------------------
+# Blas level 2 functions
+
+def trsv(kind):
+ name = typeify("trsv", kind)
+ F = Func(pkg(name), Void,
+ Params(
+ (UInt32, "flag"), (Int, "len"), (Ptr(kind), "m"), (Int, "incm"), (Ptr(kind), "x"), (Int, "incx")
+ ),
+ Vars(
+ (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
+ )
+ )
+
+ r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
+ flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x")
+ incx = F.variables("incx")
+
+ rows, cols = lambda inc_r: Iter(r, nr, _len, inc_r), lambda inc_c: Iter(c, nc, _len, inc_c)
+
+ template = lambda inc_r, inc_c: TriangularLoop(rows(inc_r), cols(inc_c),
+ Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Add(c, I(j))), Mul(Index(res, I(i)), Index(x, Add(c, I(j))))) for j in range(inc)],
+ Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))),
+ lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))],
+ upper = True
+ )
+
+ loop = template(1, 1)
+ loop.emit()
+
+def syr(kind):
+ name = typeify("syr", kind)
+ F = Func(pkg(name), Void,
+ Params(
+ (UInt32, "flag"), (Int, "len"), (kind, "a"),
+ (Ptr(kind), "x"), (Int, "incx"), (Ptr(kind), "m"), (Int, "incm"),
+ ),
+ Vars(
+ (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
+ )
+ )
+
+ r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
+ flag, _len, a, m, incm, x = F.variables("flag", "len", "a", "m", "incm", "x")
+ incx = F.variables("incx")
+
+ rows, cols = lambda inc_r: Iter(r, nr, _len, inc_r), lambda inc_c: Iter(c, nc, _len, inc_c)
+
+ template = lambda inc_r, inc_c, upper: TriangularLoop(rows(inc_r), cols(inc_c),
+ Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Add(c, I(j))), Mul(Index(res, I(i)), Index(x, Add(c, I(j))))) for j in range(inc)],
+ Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))),
+ lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))],
+ upper = upper == "upper"
+ )
+
+ blocks = []
+ for layout in ["lower", "upper"]:
+ floop = template(1, 1, layout)
+ sloop = template(1, 1, layout)
+ sloop.body = AsStrided(sloop.body, x, c, incx)
+
+ fini = template(1, 1, layout)
+ fini.init = None
+ fini.body = AsStrided(fini.body, x, c, incx)
+ fini.cond = LT(r, _len)
+
+ blocks.append(
+ Block(
+ If(UnitIncs(incx), Block(floop), Block(sloop)),
+ fini,
+ Return(),
+ )
+ )
+ F.execute(If(flag, blocks[0], blocks[1]))
+ F.emit()
+
+def ger(kind):
+ name = typeify("ger", kind)
+ F = Func(pkg(name), Void,
+ Params(
+ (Int, "nrow"), (Int, "ncol"), (kind, "a"),
+ (Ptr(kind), "x"), (Int, "incx"), (Ptr(kind), "y"), (Int, "incy"), (Ptr(kind), "m"), (Int, "incm"),
+ ),
+ Vars(
+ (Int, "r"), (Int, "c"), (Int, "nr"), (Int, "nc"), (Array(Ptr(kind), ROW), "row"), (Array(kind, COL), "res")
+ )
+ )
+
+ r, c, nr, nc, row, res = F.variables("r", "c", "nr", "nc", "row", "res")
+ nrow, ncol, a, m, incm, x, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "y")
+ incx, incy = F.variables("incx", "incy")
+
+ rows, cols = lambda incr: Iter(r, nr, nrow, incr), lambda incc: Iter(c, nc, ncol, incc)
+
+ template = lambda incr, incc: DoubleLoop(rows(incr), cols(incc),
+ Kernel = lambda r, c, i, inc: [AddSet(Index(Index(row, I(i)), Add(c, I(j))), Mul(Index(res, I(i)), Index(y, Add(c, I(j))))) for j in range(inc)],
+ Preamble = [lambda i: Set(Index(row, I(i)), Add(m, Mul(Paren(Add(r, I(i))), incm))),
+ lambda i: Set(Index(res, I(i)), Mul(a, Index(x, Add(r, I(i)))))],
+ )
+
+ # loop = template(1, 1)
+ # F.execute(loop)
+ # F.emit()
+ floop = template(ROW, COL)
+ sloop = template(ROW, COL)
+ sloop.body = AsStrided(AsStrided(sloop.body, x, c, incx), y, r, incy)
+
+ fini = template(1, 2*COL)
+ fini.init = None
+ fini.body = AsStrided(AsStrided(fini.body, x, c, incx), y, r, incy)
+ fini.cond = LT(r, nrow)
+
+ F.execute(
+ Set(nr, EvenTo(nrow, ROW)),
+ Set(nc, EvenTo(ncol, COL)),
+ If(UnitIncs(incx, incy), Block(floop), Block(sloop))
+ )
+ F.execute(fini)
+ F.emit()
+
+def gemv(kind):
+ name = typeify("gemv", kind)
+ params = Params(
+ (Int, "nrow"), (Int, "ncol"), (kind, "a"), (Ptr(kind), "m"), (Int, "incm"),
+ (Ptr(kind), "x"), (Int, "incx"), (kind, "b"), (Ptr(kind), "y"), (Int, "incy")
+ )
+ stack = Vars((Int, "r"), (Int, "c"), (Ptr(kind), "row"), (kind, "res"))
+ F = Func(pkg(name), Void, params, stack)
+
+ # ---------------------
+ # Kernel
+
+ def innerloop(rinc, cit, cend, cinc):
+ return For(Set(cit, I(0)), LT(cit, cend), AddSet(cit, I(cinc)),
+ Block(*[AddSet(TryIndex(res, I(i)),
+ AddElts(*(Mul(
+ Index(TryIndex(row, I(i)), Add(cit, I(j))),
+ Index(x, Add(cit, I(j)))) for j in range(cinc)
+ )
+ )
+ ) for i in range(rinc)])
+ )
+
+ def tryloop(rinc, cit, cend, cinc):
+ if cinc > 1:
+ loop = innerloop(rinc, cit, cend, 1)
+ loop.init = None
+ return loop
+
+ def outerloop(rit, rlen, rinc, cit, cend, clen, cinc, row, res):
+ return For(Set(rit, I(0)), LT(rit, rlen), AddSet(r, I(rinc)),
+ Block(
+ *[Set(TryIndex(row, I(i)), Add(m, Mul(Paren(Add(rit, I(i))), incm))) for i in range(rinc)],
+ *[Set(TryIndex(res, I(i)), I(0)) for i in range(rinc)],
+ innerloop(rinc, cit, cend, cinc),
+ tryloop(rinc, cit, clen, cinc),
+ *[Set(Index(y, Add(rit, I(i))), Add(Mul(a, TryIndex(res, I(i))), Mul(b, Index(y, Add(rit, I(i)))))) for i in range(rinc)]
+ )
+ )
+
+ kerns = []
+ for func, sfx in [(IsInc, ""), (Identity, "_s")]:
+ kern = Func(f"{name}{sfx}_{ROW}x{COL}kernel", Void, FilterParams(params, func), stack[0:2] + toarray(ROW, *stack[2:]), static=True)
+ r, c, row, res = kern.variables("r", "c", "row", "res")
+ nrow, ncol, a, m, incm, x, b, y = kern.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")
+
+ ncolr = kern.declare(Var(Int, "ncolr"))
+ loop = outerloop(r, nrow, ROW, c, ncolr, ncol, COL, row, res)
+
+ kern.execute(Round(ncolr, ncol, COL))
+ kern.execute(loop)
+ if "_s" in sfx:
+ incx, incy = kern.variables("incx", "incy")
+ StrideAllIndexedTerms(kern.stmts, x, c, incx)
+ StrideAllIndexedTerms(kern.stmts, y, r, incy)
+
+ kern.emit()
+
+ kerns.append(kern)
+
+ r, c, row, res = F.variables("r", "c", "row", "res")
+ nrow, ncol, a, m, incm, x, b, y = F.variables("nrow", "ncol", "a", "m", "incm", "x", "b", "y")
+ incx, incy = F.variables("incx", "incy")
+ F.execute(Round(r, nrow, ROW))
+ F.execute(
+ If(UnitIncs(incx, incy),
+ Block(Call(kerns[0], [r, ncol, a, m, incm, x, b, y])),
+ Block(Call(kerns[1], [r, ncol, a, m, incm, x, incx, b, y, incy])),
+ )
+ )
+
+ F.params = Params((UInt32, "flag")) + F.params
+
+ remainder = outerloop(r, nrow, 1, c, ncol, ncol, COL, row, res)
+ remainder.init = None
+ F.execute(remainder)
+ StrideAllIndexedTerms(F, x, c, incx)
+ StrideAllIndexedTerms(F, y, r, incy)
+
+ F.emit()
+
+# ------------------------------------------------------------------------
+# Code Generation
+
+if __name__ == "__main__":
+ emit("#include <u.h>\n")
+ emit("#include <libn.h>\n")
+ emit("#include <libmath.h>\n")
+ emitln()
+ emit("/*********************************************************/\n")
+ emit("/* THIS CODE IS GENERATED BY GEN2.PY! DON'T EDIT BY HAND */\n")
+ emit("/*********************************************************/\n")
+ emitln(2)
+
+ for kind in [Float64]: #[Float32, Float64]:
+ trsv(kind)
+ # syr(kind)
+ # ger(kind)
+ # gemv(kind)
+
+ flush()