aboutsummaryrefslogtreecommitdiff
path: root/src/libmath/blas3.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/libmath/blas3.c')
-rw-r--r--src/libmath/blas3.c279
1 files changed, 279 insertions, 0 deletions
diff --git a/src/libmath/blas3.c b/src/libmath/blas3.c
new file mode 100644
index 0000000..b048c95
--- /dev/null
+++ b/src/libmath/blas3.c
@@ -0,0 +1,279 @@
+#include <u.h>
+#include <base.h>
+#include <libmath.h>
+
+#define INT int
+#define FLOAT double
+#define func(name) blas·d##name
+
+#define X(i, j) x[j + incx*(i)]
+#define Y(i, j) y[j + incy*(i)]
+#define Z(i, j) z[j + incz*(i)]
+
+void
+func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
+{
+ INT jj, jb, kk, kb, dk, i, j, k, end;
+ FLOAT r0[8], r1[8], r2[8], r3[8], pf;
+
+ for (i = 0; i < ni; i++) {
+ for (j = 0; j < nj; j++) {
+ Z(i,j) *= b;
+ }
+ }
+
+ jb = MIN(256, nj);
+ kb = MIN(48, nk);
+ for (jj = 0; jj < nj; jj += jb) {
+ for (kk = 0; kk < nk; kk += kb) {
+ for (i = 0; i < ni; i += 4) {
+ for (j = jj; j < jj + jb; j += 8) {
+ r0[0] = Z(i+0, j+0); r0[1] = Z(i+0, j+1); r0[2] = Z(i+0, j+2); r0[3] = Z(i+0, j+3);
+ r1[0] = Z(i+1, j+0); r1[1] = Z(i+1, j+1); r1[2] = Z(i+1, j+2); r1[3] = Z(i+1, j+3);
+ r2[0] = Z(i+2, j+0); r2[1] = Z(i+2, j+1); r2[2] = Z(i+2, j+2); r2[3] = Z(i+2, j+3);
+ r3[0] = Z(i+3, j+0); r3[1] = Z(i+3, j+1); r3[2] = Z(i+3, j+2); r3[3] = Z(i+3, j+3);
+ end = MIN(nk, kk+kb);
+ for (k = kk; k < end; k++) {
+ pf = a * X(i, k);
+ r0[0] += pf * Y(k, j+0); r0[1] += pf * Y(k, j+1); r0[2] += pf * Y(k, j+2); r0[3] += pf * Y(k, j+3);
+
+ pf = a * X(i+1, k);
+ r1[0] += pf * Y(k, j+0); r1[1] += pf * Y(k, j+1); r1[2] += pf * Y(k, j+2); r1[3] += pf * Y(k, j+3);
+
+ pf = a * X(i+2, k);
+ r1[0] += pf * Y(k, j+0); r1[1] += pf * Y(k, j+1); r1[2] += pf * Y(k, j+2); r1[3] += pf * Y(k, j+3);
+
+ pf = a * X(i+3, k);
+ r1[0] += pf * Y(k, j+0); r1[1] += pf * Y(k, j+1); r1[2] += pf * Y(k, j+2); r1[3] += pf * Y(k, j+3);
+ }
+ Z(i+0, j+0) = r0[0]; Z(i+0, j+1) = r0[1]; Z(i+0, j+2) = r0[2]; Z(i+0, j+3) = r0[3];
+ Z(i+1, j+0) = r1[0]; Z(i+1, j+1) = r1[1]; Z(i+1, j+2) = r1[2]; Z(i+1, j+3) = r1[3];
+ Z(i+2, j+0) = r2[0]; Z(i+2, j+1) = r2[1]; Z(i+2, j+2) = r2[2]; Z(i+2, j+3) = r2[3];
+ Z(i+3, j+0) = r3[0]; Z(i+3, j+1) = r3[1]; Z(i+3, j+2) = r3[2]; Z(i+3, j+3) = r3[3];
+ }
+ }
+ }
+ }
+}
+
+#if 0
+void
+func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
+{
+ int i, j, k;
+ FLOAT w[nj*nk], acc[4][4];
+
+ for (i = 0; i < ni; i++) {
+ for (j = 0; j < nj; j++) {
+ Z(i,j) *= b;
+ W(i,j) = Y(j,i);
+ }
+ }
+
+ for (i = 0; i < ni; i+=4) {
+ for (j = 0; j < nj; j+=4) {
+ memset(acc, 0, sizeof(acc));
+ for (k = 0; k < nk; k+=4) {
+ acc[0][0] += X(i+0,k)*W(j+0,k) + X(i+0,k+1)*W(j+0,k+1) + X(i+0,k+2)*W(j+0,k+2) + X(i+0,k+3)*W(j+0,k+3);
+ acc[0][1] += X(i+0,k)*W(j+1,k) + X(i+0,k+1)*W(j+1,k+1) + X(i+0,k+2)*W(j+1,k+2) + X(i+0,k+3)*W(j+1,k+3);
+ acc[0][2] += X(i+0,k)*W(j+2,k) + X(i+0,k+1)*W(j+2,k+1) + X(i+0,k+2)*W(j+2,k+2) + X(i+0,k+3)*W(j+2,k+3);
+ acc[0][3] += X(i+0,k)*W(j+3,k) + X(i+0,k+1)*W(j+3,k+1) + X(i+0,k+2)*W(j+3,k+2) + X(i+0,k+3)*W(j+3,k+3);
+
+ acc[1][0] += X(i+1,k)*W(j+0,k) + X(i+1,k+1)*W(j+0,k+1) + X(i+1,k+2)*W(j+0,k+2) + X(i+1,k+3)*W(j+0,k+3);
+ acc[1][1] += X(i+1,k)*W(j+1,k) + X(i+1,k+1)*W(j+1,k+1) + X(i+1,k+2)*W(j+1,k+2) + X(i+1,k+3)*W(j+1,k+3);
+ acc[1][2] += X(i+1,k)*W(j+2,k) + X(i+1,k+1)*W(j+2,k+1) + X(i+1,k+2)*W(j+2,k+2) + X(i+1,k+3)*W(j+2,k+3);
+ acc[1][3] += X(i+1,k)*W(j+3,k) + X(i+1,k+1)*W(j+3,k+1) + X(i+1,k+2)*W(j+3,k+2) + X(i+1,k+3)*W(j+3,k+3);
+
+ acc[2][0] += X(i+2,k)*W(j+0,k) + X(i+2,k+1)*W(j+0,k+1) + X(i+2,k+2)*W(j+0,k+2) + X(i+2,k+3)*W(j+0,k+3);
+ acc[2][1] += X(i+2,k)*W(j+1,k) + X(i+2,k+1)*W(j+1,k+1) + X(i+2,k+2)*W(j+1,k+2) + X(i+2,k+3)*W(j+1,k+3);
+ acc[2][2] += X(i+2,k)*W(j+2,k) + X(i+2,k+1)*W(j+2,k+1) + X(i+2,k+2)*W(j+2,k+2) + X(i+2,k+3)*W(j+2,k+3);
+ acc[2][3] += X(i+2,k)*W(j+3,k) + X(i+2,k+1)*W(j+3,k+1) + X(i+2,k+2)*W(j+3,k+2) + X(i+2,k+3)*W(j+3,k+3);
+
+ acc[2][0] += X(i+3,k)*W(j+0,k) + X(i+3,k+1)*W(j+0,k+1) + X(i+3,k+2)*W(j+0,k+2) + X(i+3,k+3)*W(j+0,k+3);
+ acc[2][1] += X(i+3,k)*W(j+1,k) + X(i+3,k+1)*W(j+1,k+1) + X(i+3,k+2)*W(j+1,k+2) + X(i+3,k+3)*W(j+1,k+3);
+ acc[2][2] += X(i+3,k)*W(j+2,k) + X(i+3,k+1)*W(j+2,k+1) + X(i+3,k+2)*W(j+2,k+2) + X(i+3,k+3)*W(j+2,k+3);
+ acc[2][3] += X(i+3,k)*W(j+3,k) + X(i+3,k+1)*W(j+3,k+1) + X(i+3,k+2)*W(j+3,k+2) + X(i+3,k+3)*W(j+3,k+3);
+ // Z(i,j) += X(i,k)*Y(k,j);
+ }
+ Z(i+0,j+1) = a*acc[0][0];
+ Z(i+0,j+2) = a*acc[0][1];
+ Z(i+0,j+3) = a*acc[0][2];
+ Z(i+0,j+4) = a*acc[0][3];
+
+ Z(i+1,j+1) = a*acc[1][0];
+ Z(i+1,j+2) = a*acc[1][1];
+ Z(i+1,j+3) = a*acc[1][2];
+ Z(i+1,j+4) = a*acc[1][3];
+
+ Z(i+2,j+1) = a*acc[2][0];
+ Z(i+2,j+2) = a*acc[2][1];
+ Z(i+2,j+3) = a*acc[2][2];
+ Z(i+2,j+4) = a*acc[2][3];
+
+ Z(i+3,j+1) = a*acc[3][0];
+ Z(i+3,j+2) = a*acc[3][1];
+ Z(i+3,j+3) = a*acc[3][2];
+ Z(i+3,j+4) = a*acc[3][3];
+ }
+ }
+}
+#endif
+
+#if 0
+void
+func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
+{
+ int i, j, k, ri, rj, rk;
+ FLOAT reg[4][4], *xrow[4], *yrow[4];
+
+ for (i = 0; i < ni; i++) {
+ for (j = 0; j < nj; j++) {
+ z[j + incz*i] *= b;
+ }
+ }
+
+ for (i = 0; i < ni; i += 4) {
+ xrow[0] = x + incx*(i+0);
+ xrow[1] = x + incx*(i+1);
+ xrow[2] = x + incx*(i+2);
+ xrow[3] = x + incx*(i+3);
+ for (k = 0; k < nk; k+=4) {
+ yrow[0] = y + incy*(k+0);
+ yrow[1] = y + incy*(k+1);
+ yrow[2] = y + incy*(k+2);
+ yrow[3] = y + incy*(k+3);
+ reg[0][0] = a * xrow[0][k+0]; reg[0][1] = a * xrow[0][k+1]; reg[0][2] = a * xrow[0][k+2]; reg[0][3] = a * xrow[0][k+3];
+ reg[1][0] = a * xrow[1][k+0]; reg[1][1] = a * xrow[1][k+1]; reg[1][2] = a * xrow[1][k+2]; reg[1][3] = a * xrow[1][k+3];
+ reg[2][0] = a * xrow[2][k+0]; reg[2][1] = a * xrow[2][k+1]; reg[2][2] = a * xrow[2][k+2]; reg[2][3] = a * xrow[2][k+3];
+ reg[3][0] = a * xrow[3][k+0]; reg[3][1] = a * xrow[3][k+1]; reg[3][2] = a * xrow[3][k+2]; reg[3][3] = a * xrow[3][k+3];
+ for (j = 0; j < nj; j += 1) {
+ z[j + incz*(i+0)] += (reg[0][0]*yrow[0][j]+reg[0][1]*yrow[1][j]+reg[0][2]*yrow[2][j]+reg[0][3]*yrow[3][j]);
+ z[j + incz*(i+1)] += (reg[1][0]*yrow[0][j]+reg[1][1]*yrow[1][j]+reg[1][2]*yrow[2][j]+reg[1][3]*yrow[3][j]);
+ z[j + incz*(i+2)] += (reg[2][0]*yrow[0][j]+reg[2][1]*yrow[1][j]+reg[2][2]*yrow[2][j]+reg[2][3]*yrow[3][j]);
+ z[j + incz*(i+3)] += (reg[3][0]*yrow[0][j]+reg[3][1]*yrow[1][j]+reg[3][2]*yrow[2][j]+reg[3][3]*yrow[3][j]);
+ }
+ }
+ }
+}
+#endif
+
+#if 0
+void
+func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
+{
+ int i, j, k, ri, rj, rk;
+ FLOAT r[4][4], *row[4];
+
+ for (i = 0; i < ni; i++) {
+ for (j = 0; j < nj; j++) {
+ Z(i, j) *= b;
+ }
+ }
+
+ for (i = 0; i < ni; i+=4) {
+ for (j = 0; j < nj; j+=4) {
+ r[0][0] = 0; r[0][1] = 0; r[0][2] = 0; r[0][3] = 0;
+ r[1][0] = 0; r[1][1] = 0; r[1][2] = 0; r[1][3] = 0;
+ r[2][0] = 0; r[2][1] = 0; r[2][2] = 0; r[2][3] = 0;
+ r[3][0] = 0; r[3][1] = 0; r[3][2] = 0; r[3][3] = 0;
+ row[0] = &X(i+0, 0);
+ row[1] = &X(i+1, 0);
+ row[2] = &X(i+2, 0);
+ row[3] = &X(i+3, 0);
+ for (k = 0; k < nk; k++) {
+ r[0][0] += row[0][k]*Y(k,0); r[0][1] += row[0][k]*Y(k,1); r[0][2] += row[0][k]*Y(k,2); r[0][3] += row[0][k]*Y(k,3);
+ r[1][0] += row[1][k]*Y(k,0); r[1][1] += row[1][k]*Y(k,1); r[1][2] += row[1][k]*Y(k,2); r[1][3] += row[1][k]*Y(k,3);
+ r[2][0] += row[2][k]*Y(k,0); r[2][1] += row[2][k]*Y(k,1); r[2][2] += row[2][k]*Y(k,2); r[2][3] += row[2][k]*Y(k,3);
+ r[3][0] += row[3][k]*Y(k,0); r[3][1] += row[3][k]*Y(k,1); r[3][2] += row[3][k]*Y(k,2); r[3][3] += row[3][k]*Y(k,3);
+ }
+ Z(i+0, j+0) += r[0][0]; Z(i+0, j+1) += r[0][1]; Z(i+0, j+2) += r[0][2]; Z(i+0, j+3) += r[0][3];
+ Z(i+1, j+0) += r[1][0]; Z(i+1, j+1) += r[1][1]; Z(i+1, j+2) += r[1][2]; Z(i+1, j+3) += r[1][3];
+ Z(i+2, j+0) += r[2][0]; Z(i+2, j+1) += r[2][1]; Z(i+2, j+2) += r[2][2]; Z(i+2, j+3) += r[2][3];
+ Z(i+3, j+0) += r[3][0]; Z(i+3, j+1) += r[3][1]; Z(i+3, j+2) += r[3][2]; Z(i+3, j+3) += r[3][3];
+ }
+ }
+}
+#endif
+
+#if 0
+void
+func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
+{
+ int i, j, k, ri, rj, rk;
+ FLOAT *xrow[8], *yrow[8], reg;
+
+ for (i = 0; i < ni; i++) {
+ for (j = 0; j < nj; j++) {
+ z[j + incz*i] *= b;
+ }
+ }
+
+ ri = ni & ~7;
+ rj = nj & ~7;
+ for (i = 0; i < ri; i += 8) {
+ xrow[0] = x + incx*(i+0);
+ xrow[1] = x + incx*(i+1);
+ xrow[2] = x + incx*(i+2);
+ xrow[3] = x + incx*(i+3);
+ xrow[4] = x + incx*(i+4);
+ xrow[5] = x + incx*(i+5);
+ xrow[6] = x + incx*(i+6);
+ xrow[7] = x + incx*(i+7);
+ for (j = 0; j < rj; j += 8) {
+ yrow[0] = y + incy*(j+0);
+ yrow[1] = y + incy*(j+1);
+ yrow[2] = y + incy*(j+2);
+ yrow[3] = y + incy*(j+3);
+ yrow[4] = y + incy*(j+4);
+ yrow[5] = y + incy*(j+5);
+ yrow[6] = y + incy*(j+6);
+ yrow[7] = y + incy*(j+7);
+ for (k = 0; k < nk; k++) {
+ reg = a*(yrow[0][k] + yrow[1][k] + yrow[2][k] + yrow[3][k] + yrow[4][k] + yrow[5][k] + yrow[6][k] + yrow[7][k]);
+ z[k + incz*(i+0)] += xrow[0][k]*reg;
+ z[k + incz*(i+1)] += xrow[1][k]*reg;
+ z[k + incz*(i+2)] += xrow[2][k]*reg;
+ z[k + incz*(i+3)] += xrow[3][k]*reg;
+ z[k + incz*(i+4)] += xrow[4][k]*reg;
+ z[k + incz*(i+5)] += xrow[5][k]*reg;
+ z[k + incz*(i+6)] += xrow[6][k]*reg;
+ z[k + incz*(i+7)] += xrow[7][k]*reg;
+ }
+ }
+ for (; j < nj; j++) {
+ for (k = 0; k < nk; k++) {
+ reg = a*y[k+incy*j];
+ z[k + incz*(i+0)] += xrow[0][k]*reg;
+ z[k + incz*(i+1)] += xrow[1][k]*reg;
+ z[k + incz*(i+2)] += xrow[2][k]*reg;
+ z[k + incz*(i+3)] += xrow[3][k]*reg;
+ z[k + incz*(i+4)] += xrow[4][k]*reg;
+ z[k + incz*(i+5)] += xrow[5][k]*reg;
+ z[k + incz*(i+6)] += xrow[6][k]*reg;
+ z[k + incz*(i+7)] += xrow[7][k]*reg;
+ }
+ }
+ }
+
+ for (; i < ni; i++) {
+ for (j = 0; j < rj; j += 8) {
+ yrow[0] = y + incy*(j+0);
+ yrow[1] = y + incy*(j+1);
+ yrow[2] = y + incy*(j+2);
+ yrow[3] = y + incy*(j+3);
+ yrow[4] = y + incy*(j+4);
+ yrow[5] = y + incy*(j+5);
+ yrow[6] = y + incy*(j+6);
+ yrow[7] = y + incy*(j+7);
+ for (k = 0; k < nk; k++) {
+ z[k + incz*(i)] += a*x[k + incx*i]*(yrow[0][k] + yrow[1][k] + yrow[2][k] + yrow[3][k] + yrow[4][k] + yrow[5][k] + yrow[6][k] + yrow[7][k]);
+ }
+ }
+ for (; j < nj; j++) {
+ for (k = 0; k < nk; k++) {
+ z[k + incz*i] += a*x[k + incx*i]*y[k + incy*j];
+ }
+ }
+ }
+}
+#endif