/*
 * Double-precision general matrix multiply kernel.
 *
 * NOTE(review): this copy of the file shows three `#include` directives
 * whose header names are missing, so they cannot be reproduced here.
 * Restore them above this comment if the build needs them.  The active
 * code below only needs MIN, for which a guarded fallback is provided.
 */

#define INT   int
#define FLOAT double

/* Builds the exported symbol name: func(gemm) expands to blas·dgemm. */
#define func(name) blas·d##name

#ifndef MIN
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#endif

/*
 * Element (i, j) of each operand, addressed as base[j + inc*i].
 * The macros pick up the local names x/incx, y/incy and z/incz.
 */
#define X(i, j) x[(j) + incx*(i)]
#define Y(i, j) y[(j) + incy*(i)]
#define Z(i, j) z[(j) + incz*(i)]

enum {
    GEMM_TILE    = 4,   /* rows and columns held in the register tile  */
    GEMM_BLOCK_J = 256, /* columns of Z processed per outer block      */
    GEMM_BLOCK_K = 48   /* length of the inner-product slice per block */
};

/*
 * Scalar fallback: Z(i, j) += a * sum_k X(i, k) * Y(k, j) for
 * i in [i0, i1), j in [j0, j1), k in [k0, k1).
 * Used for the rows and columns that do not fill a whole 4x4 tile.
 */
static void
gemm_scalar_block(INT i0, INT i1, INT j0, INT j1, INT k0, INT k1,
                  FLOAT a, const FLOAT *x, INT incx,
                  const FLOAT *y, INT incy, FLOAT *z, INT incz)
{
    INT i, j, k;
    FLOAT pf;

    for (i = i0; i < i1; i++) {
        for (k = k0; k < k1; k++) {
            pf = a * X(i, k);
            for (j = j0; j < j1; j++) {
                Z(i, j) += pf * Y(k, j);
            }
        }
    }
}

/*
 * Computes Z = b*Z + a*(X * Y) in place.
 *
 *   trm, trn  accepted but not consulted by this implementation
 *   ni, nj    rows and columns of Z
 *   nk        columns of X, rows of Y
 *   a, b      scalar factors
 *   x, incx   X data; element (i, k) is x[k + incx*i]
 *   y, incy   Y data; element (k, j) is y[j + incy*k]
 *   z, incz   Z data; element (i, j) is z[j + incz*i]
 *
 * Z is first scaled by b.  The product is then accumulated block by
 * block; full 4x4 tiles go through an unrolled register kernel and any
 * leftover rows/columns go through gemm_scalar_block(), so ni and nj
 * need not be multiples of the tile size.
 *
 * The flag parameters are spelled `unsigned int`; this assumes the
 * project's `uint` is a typedef of that type -- TODO confirm.
 */
void
func(gemm)(unsigned int trm, unsigned int trn, INT ni, INT nj, INT nk, FLOAT a,
           FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
{
    INT jj, kk, i, j, k, jend, jfull, kend;
    FLOAT r0[GEMM_TILE], r1[GEMM_TILE], r2[GEMM_TILE], r3[GEMM_TILE], pf;

    (void)trm;
    (void)trn;

    for (i = 0; i < ni; i++) {
        for (j = 0; j < nj; j++) {
            Z(i, j) *= b;
        }
    }

    for (jj = 0; jj < nj; jj += GEMM_BLOCK_J) {
        jend = MIN(nj, jj + GEMM_BLOCK_J);
        /* Last column of this block that still starts a full tile. */
        jfull = jj + (jend - jj) / GEMM_TILE * GEMM_TILE;

        for (kk = 0; kk < nk; kk += GEMM_BLOCK_K) {
            kend = MIN(nk, kk + GEMM_BLOCK_K);

            for (i = 0; i + GEMM_TILE <= ni; i += GEMM_TILE) {
                for (j = jj; j < jfull; j += GEMM_TILE) {
                    r0[0] = Z(i+0, j+0); r0[1] = Z(i+0, j+1);
                    r0[2] = Z(i+0, j+2); r0[3] = Z(i+0, j+3);
                    r1[0] = Z(i+1, j+0); r1[1] = Z(i+1, j+1);
                    r1[2] = Z(i+1, j+2); r1[3] = Z(i+1, j+3);
                    r2[0] = Z(i+2, j+0); r2[1] = Z(i+2, j+1);
                    r2[2] = Z(i+2, j+2); r2[3] = Z(i+2, j+3);
                    r3[0] = Z(i+3, j+0); r3[1] = Z(i+3, j+1);
                    r3[2] = Z(i+3, j+2); r3[3] = Z(i+3, j+3);

                    for (k = kk; k < kend; k++) {
                        pf = a * X(i+0, k);
                        r0[0] += pf * Y(k, j+0); r0[1] += pf * Y(k, j+1);
                        r0[2] += pf * Y(k, j+2); r0[3] += pf * Y(k, j+3);

                        pf = a * X(i+1, k);
                        r1[0] += pf * Y(k, j+0); r1[1] += pf * Y(k, j+1);
                        r1[2] += pf * Y(k, j+2); r1[3] += pf * Y(k, j+3);

                        /* Each row accumulates into its own registers. */
                        pf = a * X(i+2, k);
                        r2[0] += pf * Y(k, j+0); r2[1] += pf * Y(k, j+1);
                        r2[2] += pf * Y(k, j+2); r2[3] += pf * Y(k, j+3);

                        pf = a * X(i+3, k);
                        r3[0] += pf * Y(k, j+0); r3[1] += pf * Y(k, j+1);
                        r3[2] += pf * Y(k, j+2); r3[3] += pf * Y(k, j+3);
                    }

                    Z(i+0, j+0) = r0[0]; Z(i+0, j+1) = r0[1];
                    Z(i+0, j+2) = r0[2]; Z(i+0, j+3) = r0[3];
                    Z(i+1, j+0) = r1[0]; Z(i+1, j+1) = r1[1];
                    Z(i+1, j+2) = r1[2]; Z(i+1, j+3) = r1[3];
                    Z(i+2, j+0) = r2[0]; Z(i+2, j+1) = r2[1];
                    Z(i+2, j+2) = r2[2]; Z(i+2, j+3) = r2[3];
                    Z(i+3, j+0) = r3[0]; Z(i+3, j+1) = r3[1];
                    Z(i+3, j+2) = r3[2]; Z(i+3, j+3) = r3[3];
                }

                /* Columns of this block that do not fill a tile. */
                if (jfull < jend) {
                    gemm_scalar_block(i, i + GEMM_TILE, jfull, jend, kk, kend,
                                      a, x, incx, y, incy, z, incz);
                }
            }

            /* Rows that do not fill a tile. */
            if (i < ni) {
                gemm_scalar_block(i, ni, jj, jend, kk, kend,
                                  a, x, incx, y, incy, z, incz);
            }
        }
    }
}

/*
 * Disabled variant: 4x4 accumulator tile over a transposed copy of Y.
 * NOTE(review): as written it uses a W() macro that is not defined in
 * this file, never accumulates into acc[3][*] (the i+3 rows are added to
 * acc[2][*]), and stores to columns j+1..j+4, overwriting the b-scaled Z.
 */
#if 0
void func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
{
    int i, j, k;
    FLOAT w[nj*nk], acc[4][4];

    for (i = 0; i < ni; i++) {
        for (j = 0; j < nj; j++) {
            Z(i,j) *= b;
            W(i,j) = Y(j,i);
        }
    }

    for (i = 0; i < ni; i+=4) {
        for (j = 0; j < nj; j+=4) {
            memset(acc, 0, sizeof(acc));
            for (k = 0; k < nk; k+=4) {
                acc[0][0] += X(i+0,k)*W(j+0,k) + X(i+0,k+1)*W(j+0,k+1) + X(i+0,k+2)*W(j+0,k+2) + X(i+0,k+3)*W(j+0,k+3);
                acc[0][1] += X(i+0,k)*W(j+1,k) + X(i+0,k+1)*W(j+1,k+1) + X(i+0,k+2)*W(j+1,k+2) + X(i+0,k+3)*W(j+1,k+3);
                acc[0][2] += X(i+0,k)*W(j+2,k) + X(i+0,k+1)*W(j+2,k+1) + X(i+0,k+2)*W(j+2,k+2) + X(i+0,k+3)*W(j+2,k+3);
                acc[0][3] += X(i+0,k)*W(j+3,k) + X(i+0,k+1)*W(j+3,k+1) + X(i+0,k+2)*W(j+3,k+2) + X(i+0,k+3)*W(j+3,k+3);

                acc[1][0] += X(i+1,k)*W(j+0,k) + X(i+1,k+1)*W(j+0,k+1) + X(i+1,k+2)*W(j+0,k+2) + X(i+1,k+3)*W(j+0,k+3);
                acc[1][1] += X(i+1,k)*W(j+1,k) + X(i+1,k+1)*W(j+1,k+1) + X(i+1,k+2)*W(j+1,k+2) + X(i+1,k+3)*W(j+1,k+3);
                acc[1][2] += X(i+1,k)*W(j+2,k) + X(i+1,k+1)*W(j+2,k+1) + X(i+1,k+2)*W(j+2,k+2) + X(i+1,k+3)*W(j+2,k+3);
                acc[1][3] += X(i+1,k)*W(j+3,k) + X(i+1,k+1)*W(j+3,k+1) + X(i+1,k+2)*W(j+3,k+2) + X(i+1,k+3)*W(j+3,k+3);

                acc[2][0] += X(i+2,k)*W(j+0,k) + X(i+2,k+1)*W(j+0,k+1) + X(i+2,k+2)*W(j+0,k+2) + X(i+2,k+3)*W(j+0,k+3);
                acc[2][1] += X(i+2,k)*W(j+1,k) + X(i+2,k+1)*W(j+1,k+1) + X(i+2,k+2)*W(j+1,k+2) + X(i+2,k+3)*W(j+1,k+3);
                acc[2][2] += X(i+2,k)*W(j+2,k) + X(i+2,k+1)*W(j+2,k+1) + X(i+2,k+2)*W(j+2,k+2) + X(i+2,k+3)*W(j+2,k+3);
                acc[2][3] += X(i+2,k)*W(j+3,k) + X(i+2,k+1)*W(j+3,k+1) + X(i+2,k+2)*W(j+3,k+2) + X(i+2,k+3)*W(j+3,k+3);

                acc[2][0] += X(i+3,k)*W(j+0,k) + X(i+3,k+1)*W(j+0,k+1) + X(i+3,k+2)*W(j+0,k+2) + X(i+3,k+3)*W(j+0,k+3);
                acc[2][1] += X(i+3,k)*W(j+1,k) + X(i+3,k+1)*W(j+1,k+1) + X(i+3,k+2)*W(j+1,k+2) + X(i+3,k+3)*W(j+1,k+3);
                acc[2][2] += X(i+3,k)*W(j+2,k) + X(i+3,k+1)*W(j+2,k+1) + X(i+3,k+2)*W(j+2,k+2) + X(i+3,k+3)*W(j+2,k+3);
                acc[2][3] += X(i+3,k)*W(j+3,k) + X(i+3,k+1)*W(j+3,k+1) + X(i+3,k+2)*W(j+3,k+2) + X(i+3,k+3)*W(j+3,k+3);
                // Z(i,j) += X(i,k)*Y(k,j);
            }
            Z(i+0,j+1) = a*acc[0][0];
            Z(i+0,j+2) = a*acc[0][1];
            Z(i+0,j+3) = a*acc[0][2];
            Z(i+0,j+4) = a*acc[0][3];

            Z(i+1,j+1) = a*acc[1][0];
            Z(i+1,j+2) = a*acc[1][1];
            Z(i+1,j+3) = a*acc[1][2];
            Z(i+1,j+4) = a*acc[1][3];

            Z(i+2,j+1) = a*acc[2][0];
            Z(i+2,j+2) = a*acc[2][1];
            Z(i+2,j+3) = a*acc[2][2];
            Z(i+2,j+4) = a*acc[2][3];

            Z(i+3,j+1) = a*acc[3][0];
            Z(i+3,j+2) = a*acc[3][1];
            Z(i+3,j+3) = a*acc[3][2];
            Z(i+3,j+4) = a*acc[3][3];
        }
    }
}
#endif

/*
 * Disabled variant: caches a 4x4 block of a*X in reg[][] and streams over
 * every column j for four rows of Y at a time.
 * NOTE(review): steps i and k by 4 with no handling for ni or nk that are
 * not multiples of 4; ri, rj and rk are unused.
 */
#if 0
void func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
{
    int i, j, k, ri, rj, rk;
    FLOAT reg[4][4], *xrow[4], *yrow[4];

    for (i = 0; i < ni; i++) {
        for (j = 0; j < nj; j++) {
            z[j + incz*i] *= b;
        }
    }

    for (i = 0; i < ni; i += 4) {
        xrow[0] = x + incx*(i+0);
        xrow[1] = x + incx*(i+1);
        xrow[2] = x + incx*(i+2);
        xrow[3] = x + incx*(i+3);
        for (k = 0; k < nk; k+=4) {
            yrow[0] = y + incy*(k+0);
            yrow[1] = y + incy*(k+1);
            yrow[2] = y + incy*(k+2);
            yrow[3] = y + incy*(k+3);

            reg[0][0] = a * xrow[0][k+0];
            reg[0][1] = a * xrow[0][k+1];
            reg[0][2] = a * xrow[0][k+2];
            reg[0][3] = a * xrow[0][k+3];

            reg[1][0] = a * xrow[1][k+0];
            reg[1][1] = a * xrow[1][k+1];
            reg[1][2] = a * xrow[1][k+2];
            reg[1][3] = a * xrow[1][k+3];

            reg[2][0] = a * xrow[2][k+0];
            reg[2][1] = a * xrow[2][k+1];
            reg[2][2] = a * xrow[2][k+2];
            reg[2][3] = a * xrow[2][k+3];

            reg[3][0] = a * xrow[3][k+0];
            reg[3][1] = a * xrow[3][k+1];
            reg[3][2] = a * xrow[3][k+2];
            reg[3][3] = a * xrow[3][k+3];

            for (j = 0; j < nj; j += 1) {
                z[j + incz*(i+0)] += (reg[0][0]*yrow[0][j]+reg[0][1]*yrow[1][j]+reg[0][2]*yrow[2][j]+reg[0][3]*yrow[3][j]);
                z[j + incz*(i+1)] += (reg[1][0]*yrow[0][j]+reg[1][1]*yrow[1][j]+reg[1][2]*yrow[2][j]+reg[1][3]*yrow[3][j]);
                z[j + incz*(i+2)] += (reg[2][0]*yrow[0][j]+reg[2][1]*yrow[1][j]+reg[2][2]*yrow[2][j]+reg[2][3]*yrow[3][j]);
                z[j + incz*(i+3)] += (reg[3][0]*yrow[0][j]+reg[3][1]*yrow[1][j]+reg[3][2]*yrow[2][j]+reg[3][3]*yrow[3][j]);
            }
        }
    }
}
#endif

/*
 * Disabled variant: 4x4 accumulator tile r[][] summed over the full k range.
 * NOTE(review): the inner loop reads Y(k,0)..Y(k,3) regardless of j and
 * never applies the factor a; ri, rj and rk are unused.
 */
#if 0
void func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
{
    int i, j, k, ri, rj, rk;
    FLOAT r[4][4], *row[4];

    for (i = 0; i < ni; i++) {
        for (j = 0; j < nj; j++) {
            Z(i, j) *= b;
        }
    }

    for (i = 0; i < ni; i+=4) {
        for (j = 0; j < nj; j+=4) {
            r[0][0] = 0; r[0][1] = 0; r[0][2] = 0; r[0][3] = 0;
            r[1][0] = 0; r[1][1] = 0; r[1][2] = 0; r[1][3] = 0;
            r[2][0] = 0; r[2][1] = 0; r[2][2] = 0; r[2][3] = 0;
            r[3][0] = 0; r[3][1] = 0; r[3][2] = 0; r[3][3] = 0;

            row[0] = &X(i+0, 0);
            row[1] = &X(i+1, 0);
            row[2] = &X(i+2, 0);
            row[3] = &X(i+3, 0);

            for (k = 0; k < nk; k++) {
                r[0][0] += row[0][k]*Y(k,0);
                r[0][1] += row[0][k]*Y(k,1);
                r[0][2] += row[0][k]*Y(k,2);
                r[0][3] += row[0][k]*Y(k,3);

                r[1][0] += row[1][k]*Y(k,0);
                r[1][1] += row[1][k]*Y(k,1);
                r[1][2] += row[1][k]*Y(k,2);
                r[1][3] += row[1][k]*Y(k,3);

                r[2][0] += row[2][k]*Y(k,0);
                r[2][1] += row[2][k]*Y(k,1);
                r[2][2] += row[2][k]*Y(k,2);
                r[2][3] += row[2][k]*Y(k,3);

                r[3][0] += row[3][k]*Y(k,0);
                r[3][1] += row[3][k]*Y(k,1);
                r[3][2] += row[3][k]*Y(k,2);
                r[3][3] += row[3][k]*Y(k,3);
            }

            Z(i+0, j+0) += r[0][0];
            Z(i+0, j+1) += r[0][1];
            Z(i+0, j+2) += r[0][2];
            Z(i+0, j+3) += r[0][3];

            Z(i+1, j+0) += r[1][0];
            Z(i+1, j+1) += r[1][1];
            Z(i+1, j+2) += r[1][2];
            Z(i+1, j+3) += r[1][3];

            Z(i+2, j+0) += r[2][0];
            Z(i+2, j+1) += r[2][1];
            Z(i+2, j+2) += r[2][2];
            Z(i+2, j+3) += r[2][3];

            Z(i+3, j+0) += r[3][0];
            Z(i+3, j+1) += r[3][1];
            Z(i+3, j+2) += r[3][2];
            Z(i+3, j+3) += r[3][3];
        }
    }
}
#endif

/*
 * Disabled variant: 8-row blocks with scalar tails for the leftover rows
 * and columns.
 * NOTE(review): it indexes y as y[k + incy*j] and z as z[k + incz*i], and
 * sums eight y rows into one factor, so it does not appear to compute the
 * same product as the active kernel -- confirm intent before reviving.
 * rk is unused.
 */
#if 0
void func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
{
    int i, j, k, ri, rj, rk;
    FLOAT *xrow[8], *yrow[8], reg;

    for (i = 0; i < ni; i++) {
        for (j = 0; j < nj; j++) {
            z[j + incz*i] *= b;
        }
    }

    ri = ni & ~7;
    rj = nj & ~7;

    for (i = 0; i < ri; i += 8) {
        xrow[0] = x + incx*(i+0);
        xrow[1] = x + incx*(i+1);
        xrow[2] = x + incx*(i+2);
        xrow[3] = x + incx*(i+3);
        xrow[4] = x + incx*(i+4);
        xrow[5] = x + incx*(i+5);
        xrow[6] = x + incx*(i+6);
        xrow[7] = x + incx*(i+7);

        for (j = 0; j < rj; j += 8) {
            yrow[0] = y + incy*(j+0);
            yrow[1] = y + incy*(j+1);
            yrow[2] = y + incy*(j+2);
            yrow[3] = y + incy*(j+3);
            yrow[4] = y + incy*(j+4);
            yrow[5] = y + incy*(j+5);
            yrow[6] = y + incy*(j+6);
            yrow[7] = y + incy*(j+7);

            for (k = 0; k < nk; k++) {
                reg = a*(yrow[0][k] + yrow[1][k] + yrow[2][k] + yrow[3][k] + yrow[4][k] + yrow[5][k] + yrow[6][k] + yrow[7][k]);
                z[k + incz*(i+0)] += xrow[0][k]*reg;
                z[k + incz*(i+1)] += xrow[1][k]*reg;
                z[k + incz*(i+2)] += xrow[2][k]*reg;
                z[k + incz*(i+3)] += xrow[3][k]*reg;
                z[k + incz*(i+4)] += xrow[4][k]*reg;
                z[k + incz*(i+5)] += xrow[5][k]*reg;
                z[k + incz*(i+6)] += xrow[6][k]*reg;
                z[k + incz*(i+7)] += xrow[7][k]*reg;
            }
        }

        for (; j < nj; j++) {
            for (k = 0; k < nk; k++) {
                reg = a*y[k+incy*j];
                z[k + incz*(i+0)] += xrow[0][k]*reg;
                z[k + incz*(i+1)] += xrow[1][k]*reg;
                z[k + incz*(i+2)] += xrow[2][k]*reg;
                z[k + incz*(i+3)] += xrow[3][k]*reg;
                z[k + incz*(i+4)] += xrow[4][k]*reg;
                z[k + incz*(i+5)] += xrow[5][k]*reg;
                z[k + incz*(i+6)] += xrow[6][k]*reg;
                z[k + incz*(i+7)] += xrow[7][k]*reg;
            }
        }
    }

    for (; i < ni; i++) {
        for (j = 0; j < rj; j += 8) {
            yrow[0] = y + incy*(j+0);
            yrow[1] = y + incy*(j+1);
            yrow[2] = y + incy*(j+2);
            yrow[3] = y + incy*(j+3);
            yrow[4] = y + incy*(j+4);
            yrow[5] = y + incy*(j+5);
            yrow[6] = y + incy*(j+6);
            yrow[7] = y + incy*(j+7);

            for (k = 0; k < nk; k++) {
                z[k + incz*(i)] += a*x[k + incx*i]*(yrow[0][k] + yrow[1][k] + yrow[2][k] + yrow[3][k] + yrow[4][k] + yrow[5][k] + yrow[6][k] + yrow[7][k]);
            }
        }

        for (; j < nj; j++) {
            for (k = 0; k < nk; k++) {
                z[k + incz*i] += a*x[k + incx*i]*y[k + incy*j];
            }
        }
    }
}
#endif