/* src/libmath/blas3.c */
#include <u.h>
#include <base.h>
#include <libmath.h>

/*
 * Type and name parameters for the routines in this file: integers are int,
 * reals are double, and func(name) pastes the routine name onto the
 * "blas·d" prefix (func(gemm) expands to blas·dgemm).
 */
#define INT   int
#define FLOAT double
#define func(name) blas·d##name

/*
 * Element accessors. X(i, j) is the element of x that lies i*incx + j
 * elements from its start (likewise Y with y/incy and Z with z/incz), so the
 * first index is scaled by the stride and the second is contiguous.
 * The macros expand in terms of the identifiers x/incx, y/incy and z/incz,
 * which must be in scope at the point of use.
 * Both arguments are parenthesized so that expressions such as `c & 3` or
 * `p ? q : r` can be passed as an index without changing meaning.
 */
#define X(i, j) x[(j) + incx*(i)]
#define Y(i, j) y[(j) + incy*(i)]
#define Z(i, j) z[(j) + incz*(i)]

/*
 * gemm: general matrix multiply-accumulate
 *
 *     z = b*z + a*(x*y)
 *
 * x is read as an ni-by-nk matrix, y as nk-by-nj, and z is updated as
 * ni-by-nj. Elements are addressed through the X/Y/Z accessor macros, i.e.
 * element (i, j) of x lives at x[j + incx*i] (likewise y/incy and z/incz).
 *
 * trm, trn: accepted for interface compatibility but not consulted; only the
 *           untransposed product is implemented.
 *           NOTE(review): presumably transpose selectors for x and y --
 *           confirm against callers before relying on them.
 * a, b:     scalar factors for the product and for the incoming z.
 *           When b is zero z is overwritten with zeros instead of being
 *           multiplied, so NaN/Inf already stored in z cannot leak into the
 *           result (the usual BLAS convention for beta == 0).
 *
 * The product is accumulated in panels of JBLOCK columns by KBLOCK terms of
 * the inner dimension, and inside a panel in TILE-by-TILE tiles of z that are
 * held in local accumulators while k runs over the panel. Tiles that would
 * cross the last row or column are clipped, so ni, nj and nk may be any
 * non-negative value; nothing is touched when a dimension is zero.
 */
void
func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
{
    enum { JBLOCK = 256, KBLOCK = 48, TILE = 4 };
    INT jj, kk, i, j, k, r, c, jend, kend, rows, cols;
    FLOAT acc[TILE][TILE], pf;

    (void)trm;
    (void)trn;

    /* z <- b*z, done once up front so the panels below only accumulate */
    for (i = 0; i < ni; i++) {
        for (j = 0; j < nj; j++) {
            Z(i, j) = (b == 0) ? 0 : b * Z(i, j);
        }
    }

    for (jj = 0; jj < nj; jj += JBLOCK) {
        jend = MIN(nj, jj + JBLOCK);            /* last column panel may be short */
        for (kk = 0; kk < nk; kk += KBLOCK) {
            kend = MIN(nk, kk + KBLOCK);        /* last k panel may be short */
            for (i = 0; i < ni; i += TILE) {
                rows = MIN(TILE, ni - i);
                for (j = jj; j < jend; j += TILE) {
                    cols = MIN(TILE, jend - j);

                    if (rows == TILE && cols == TILE) {
                        /* full tile: fixed-size, fully unrolled update */
                        acc[0][0] = Z(i+0, j+0); acc[0][1] = Z(i+0, j+1); acc[0][2] = Z(i+0, j+2); acc[0][3] = Z(i+0, j+3);
                        acc[1][0] = Z(i+1, j+0); acc[1][1] = Z(i+1, j+1); acc[1][2] = Z(i+1, j+2); acc[1][3] = Z(i+1, j+3);
                        acc[2][0] = Z(i+2, j+0); acc[2][1] = Z(i+2, j+1); acc[2][2] = Z(i+2, j+2); acc[2][3] = Z(i+2, j+3);
                        acc[3][0] = Z(i+3, j+0); acc[3][1] = Z(i+3, j+1); acc[3][2] = Z(i+3, j+2); acc[3][3] = Z(i+3, j+3);
                        for (k = kk; k < kend; k++) {
                            pf = a * X(i+0, k);
                            acc[0][0] += pf * Y(k, j+0); acc[0][1] += pf * Y(k, j+1); acc[0][2] += pf * Y(k, j+2); acc[0][3] += pf * Y(k, j+3);

                            pf = a * X(i+1, k);
                            acc[1][0] += pf * Y(k, j+0); acc[1][1] += pf * Y(k, j+1); acc[1][2] += pf * Y(k, j+2); acc[1][3] += pf * Y(k, j+3);

                            pf = a * X(i+2, k);
                            acc[2][0] += pf * Y(k, j+0); acc[2][1] += pf * Y(k, j+1); acc[2][2] += pf * Y(k, j+2); acc[2][3] += pf * Y(k, j+3);

                            pf = a * X(i+3, k);
                            acc[3][0] += pf * Y(k, j+0); acc[3][1] += pf * Y(k, j+1); acc[3][2] += pf * Y(k, j+2); acc[3][3] += pf * Y(k, j+3);
                        }
                        Z(i+0, j+0) = acc[0][0]; Z(i+0, j+1) = acc[0][1]; Z(i+0, j+2) = acc[0][2]; Z(i+0, j+3) = acc[0][3];
                        Z(i+1, j+0) = acc[1][0]; Z(i+1, j+1) = acc[1][1]; Z(i+1, j+2) = acc[1][2]; Z(i+1, j+3) = acc[1][3];
                        Z(i+2, j+0) = acc[2][0]; Z(i+2, j+1) = acc[2][1]; Z(i+2, j+2) = acc[2][2]; Z(i+2, j+3) = acc[2][3];
                        Z(i+3, j+0) = acc[3][0]; Z(i+3, j+1) = acc[3][1]; Z(i+3, j+2) = acc[3][2]; Z(i+3, j+3) = acc[3][3];
                    } else {
                        /* clipped tile on the bottom/right edge: same update, bounded by rows x cols */
                        for (r = 0; r < rows; r++) {
                            for (c = 0; c < cols; c++) {
                                acc[r][c] = Z(i+r, j+c);
                            }
                        }
                        for (k = kk; k < kend; k++) {
                            for (r = 0; r < rows; r++) {
                                pf = a * X(i+r, k);
                                for (c = 0; c < cols; c++) {
                                    acc[r][c] += pf * Y(k, j+c);
                                }
                            }
                        }
                        for (r = 0; r < rows; r++) {
                            for (c = 0; c < cols; c++) {
                                Z(i+r, j+c) = acc[r][c];
                            }
                        }
                    }
                }
            }
        }
    }
}

#if 0
/*
 * Disabled gemm variant (compiled out by the surrounding #if 0).
 *
 * Scales z by b while copying y into a local buffer w, then walks z in 4x4
 * tiles, accumulating products of X and W into acc with k advanced four at a
 * time, and finally stores a*acc into z.
 *
 * NOTE(review), visible in the code below and to be resolved before this is
 * ever re-enabled:
 *   - W(...) is used as an accessor but no W macro is defined in the visible
 *     part of this file; only the array w is declared.
 *   - w is a variable-length array of nj*nk elements on the stack, while the
 *     fill loop runs i over ni and j over nj.
 *   - the products for row i+3 are added into acc[2][*]; acc[3][*] is never
 *     accumulated and is stored as left by memset.
 *   - results are stored to columns j+1..j+4, not j+0..j+3.
 *   - the stores assign rather than add, so the b-scaled contents of z are
 *     overwritten.
 *   - i, j and k advance by 4 with no handling of a remainder.
 *   - trm and trn are not used.
 */
void
func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
{
    int i, j, k;
    FLOAT w[nj*nk], acc[4][4];

    /* scale z and copy y into w with its two indices swapped */
    for (i = 0; i < ni; i++) {
        for (j = 0; j < nj; j++) {
            Z(i,j) *= b;
            W(i,j) = Y(j,i);
        }
    }

    for (i = 0; i < ni; i+=4) {
        for (j = 0; j < nj; j+=4) {
            memset(acc, 0, sizeof(acc));
            for (k = 0; k < nk; k+=4) {
                acc[0][0] += X(i+0,k)*W(j+0,k) + X(i+0,k+1)*W(j+0,k+1) + X(i+0,k+2)*W(j+0,k+2) + X(i+0,k+3)*W(j+0,k+3);
                acc[0][1] += X(i+0,k)*W(j+1,k) + X(i+0,k+1)*W(j+1,k+1) + X(i+0,k+2)*W(j+1,k+2) + X(i+0,k+3)*W(j+1,k+3);
                acc[0][2] += X(i+0,k)*W(j+2,k) + X(i+0,k+1)*W(j+2,k+1) + X(i+0,k+2)*W(j+2,k+2) + X(i+0,k+3)*W(j+2,k+3);
                acc[0][3] += X(i+0,k)*W(j+3,k) + X(i+0,k+1)*W(j+3,k+1) + X(i+0,k+2)*W(j+3,k+2) + X(i+0,k+3)*W(j+3,k+3);

                acc[1][0] += X(i+1,k)*W(j+0,k) + X(i+1,k+1)*W(j+0,k+1) + X(i+1,k+2)*W(j+0,k+2) + X(i+1,k+3)*W(j+0,k+3);
                acc[1][1] += X(i+1,k)*W(j+1,k) + X(i+1,k+1)*W(j+1,k+1) + X(i+1,k+2)*W(j+1,k+2) + X(i+1,k+3)*W(j+1,k+3);
                acc[1][2] += X(i+1,k)*W(j+2,k) + X(i+1,k+1)*W(j+2,k+1) + X(i+1,k+2)*W(j+2,k+2) + X(i+1,k+3)*W(j+2,k+3);
                acc[1][3] += X(i+1,k)*W(j+3,k) + X(i+1,k+1)*W(j+3,k+1) + X(i+1,k+2)*W(j+3,k+2) + X(i+1,k+3)*W(j+3,k+3);

                acc[2][0] += X(i+2,k)*W(j+0,k) + X(i+2,k+1)*W(j+0,k+1) + X(i+2,k+2)*W(j+0,k+2) + X(i+2,k+3)*W(j+0,k+3);
                acc[2][1] += X(i+2,k)*W(j+1,k) + X(i+2,k+1)*W(j+1,k+1) + X(i+2,k+2)*W(j+1,k+2) + X(i+2,k+3)*W(j+1,k+3);
                acc[2][2] += X(i+2,k)*W(j+2,k) + X(i+2,k+1)*W(j+2,k+1) + X(i+2,k+2)*W(j+2,k+2) + X(i+2,k+3)*W(j+2,k+3);
                acc[2][3] += X(i+2,k)*W(j+3,k) + X(i+2,k+1)*W(j+3,k+1) + X(i+2,k+2)*W(j+3,k+2) + X(i+2,k+3)*W(j+3,k+3);

                /* NOTE(review): row i+3 products land in acc[2], not acc[3] */
                acc[2][0] += X(i+3,k)*W(j+0,k) + X(i+3,k+1)*W(j+0,k+1) + X(i+3,k+2)*W(j+0,k+2) + X(i+3,k+3)*W(j+0,k+3);
                acc[2][1] += X(i+3,k)*W(j+1,k) + X(i+3,k+1)*W(j+1,k+1) + X(i+3,k+2)*W(j+1,k+2) + X(i+3,k+3)*W(j+1,k+3);
                acc[2][2] += X(i+3,k)*W(j+2,k) + X(i+3,k+1)*W(j+2,k+1) + X(i+3,k+2)*W(j+2,k+2) + X(i+3,k+3)*W(j+2,k+3);
                acc[2][3] += X(i+3,k)*W(j+3,k) + X(i+3,k+1)*W(j+3,k+1) + X(i+3,k+2)*W(j+3,k+2) + X(i+3,k+3)*W(j+3,k+3);
                // Z(i,j) += X(i,k)*Y(k,j);
            }
            /* NOTE(review): column offsets below start at j+1 and the stores overwrite z */
            Z(i+0,j+1) = a*acc[0][0];
            Z(i+0,j+2) = a*acc[0][1];
            Z(i+0,j+3) = a*acc[0][2];
            Z(i+0,j+4) = a*acc[0][3];

            Z(i+1,j+1) = a*acc[1][0];
            Z(i+1,j+2) = a*acc[1][1];
            Z(i+1,j+3) = a*acc[1][2];
            Z(i+1,j+4) = a*acc[1][3];

            Z(i+2,j+1) = a*acc[2][0];
            Z(i+2,j+2) = a*acc[2][1];
            Z(i+2,j+3) = a*acc[2][2];
            Z(i+2,j+4) = a*acc[2][3];

            Z(i+3,j+1) = a*acc[3][0];
            Z(i+3,j+2) = a*acc[3][1];
            Z(i+3,j+3) = a*acc[3][2];
            Z(i+3,j+4) = a*acc[3][3];
        }
    }
}
#endif

#if 0
/*
 * Disabled gemm variant (compiled out by the surrounding #if 0).
 *
 * Scales z by b, then for every 4-row band of x and every group of four
 * consecutive k it preloads the sixteen values a*x[i+r][k+c] into reg and
 * sweeps j across the whole width of z, adding the four-term partial dot
 * products into rows i..i+3 of z.
 *
 * NOTE(review): rows i+1..i+3 and k+1..k+3 are read unconditionally, so the
 * loops only stay in bounds when ni and nk are multiples of 4. The locals
 * ri, rj and rk, and the parameters trm and trn, are not used.
 */
void
func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
{
    int i, j, k, ri, rj, rk;
    FLOAT reg[4][4], *xrow[4], *yrow[4];

    for (i = 0; i < ni; i++) {
        for (j = 0; j < nj; j++) {
            z[j + incz*i] *= b;
        }
    }

    for (i = 0; i < ni; i += 4) {
        /* start of rows i..i+3 of x */
        xrow[0] = x + incx*(i+0);
        xrow[1] = x + incx*(i+1);
        xrow[2] = x + incx*(i+2);
        xrow[3] = x + incx*(i+3);
        for (k = 0; k < nk; k+=4) {
            /* start of rows k..k+3 of y */
            yrow[0]   = y + incy*(k+0);
            yrow[1]   = y + incy*(k+1);
            yrow[2]   = y + incy*(k+2);
            yrow[3]   = y + incy*(k+3);
            /* reg[r][c] = a * x[i+r][k+c], reused for every j below */
            reg[0][0] = a * xrow[0][k+0]; reg[0][1] = a * xrow[0][k+1]; reg[0][2] = a * xrow[0][k+2]; reg[0][3] = a * xrow[0][k+3];
            reg[1][0] = a * xrow[1][k+0]; reg[1][1] = a * xrow[1][k+1]; reg[1][2] = a * xrow[1][k+2]; reg[1][3] = a * xrow[1][k+3];
            reg[2][0] = a * xrow[2][k+0]; reg[2][1] = a * xrow[2][k+1]; reg[2][2] = a * xrow[2][k+2]; reg[2][3] = a * xrow[2][k+3];
            reg[3][0] = a * xrow[3][k+0]; reg[3][1] = a * xrow[3][k+1]; reg[3][2] = a * xrow[3][k+2]; reg[3][3] = a * xrow[3][k+3];
            for (j = 0; j < nj; j += 1) {
                z[j + incz*(i+0)] += (reg[0][0]*yrow[0][j]+reg[0][1]*yrow[1][j]+reg[0][2]*yrow[2][j]+reg[0][3]*yrow[3][j]);
                z[j + incz*(i+1)] += (reg[1][0]*yrow[0][j]+reg[1][1]*yrow[1][j]+reg[1][2]*yrow[2][j]+reg[1][3]*yrow[3][j]);
                z[j + incz*(i+2)] += (reg[2][0]*yrow[0][j]+reg[2][1]*yrow[1][j]+reg[2][2]*yrow[2][j]+reg[2][3]*yrow[3][j]);
                z[j + incz*(i+3)] += (reg[3][0]*yrow[0][j]+reg[3][1]*yrow[1][j]+reg[3][2]*yrow[2][j]+reg[3][3]*yrow[3][j]);
            }
        }
    }
}
#endif

#if 0
/*
 * Disabled gemm variant (compiled out by the surrounding #if 0).
 *
 * Scales z by b, then for each 4x4 tile of z clears a local accumulator r,
 * runs k over the full inner dimension adding products of four rows of x
 * with entries of y, and adds r into the tile.
 *
 * NOTE(review), visible in the code below:
 *   - the y operands are Y(k,0)..Y(k,3); they do not depend on j, so every
 *     column tile accumulates the same four columns of y.
 *   - the scalar a is never applied.
 *   - i and j advance by 4 with no handling of a remainder.
 *   - the locals ri, rj, rk and the parameters trm, trn are not used.
 */
void
func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
{
    int i, j, k, ri, rj, rk;
    FLOAT r[4][4], *row[4];

    for (i = 0; i < ni; i++) {
        for (j = 0; j < nj; j++) {
            Z(i, j) *= b;
        }
    }

    for (i = 0; i < ni; i+=4) {
        for (j = 0; j < nj; j+=4) {
            r[0][0] = 0; r[0][1] = 0; r[0][2] = 0; r[0][3] = 0; 
            r[1][0] = 0; r[1][1] = 0; r[1][2] = 0; r[1][3] = 0; 
            r[2][0] = 0; r[2][1] = 0; r[2][2] = 0; r[2][3] = 0; 
            r[3][0] = 0; r[3][1] = 0; r[3][2] = 0; r[3][3] = 0; 
            /* start of rows i..i+3 of x */
            row[0]  = &X(i+0, 0);
            row[1]  = &X(i+1, 0);
            row[2]  = &X(i+2, 0);
            row[3]  = &X(i+3, 0);
            for (k = 0; k < nk; k++) {
                /* NOTE(review): column index of Y is the constant 0..3, not j+0..j+3 */
                r[0][0] += row[0][k]*Y(k,0); r[0][1] += row[0][k]*Y(k,1); r[0][2] += row[0][k]*Y(k,2); r[0][3] += row[0][k]*Y(k,3);
                r[1][0] += row[1][k]*Y(k,0); r[1][1] += row[1][k]*Y(k,1); r[1][2] += row[1][k]*Y(k,2); r[1][3] += row[1][k]*Y(k,3);
                r[2][0] += row[2][k]*Y(k,0); r[2][1] += row[2][k]*Y(k,1); r[2][2] += row[2][k]*Y(k,2); r[2][3] += row[2][k]*Y(k,3);
                r[3][0] += row[3][k]*Y(k,0); r[3][1] += row[3][k]*Y(k,1); r[3][2] += row[3][k]*Y(k,2); r[3][3] += row[3][k]*Y(k,3);
            }
            Z(i+0, j+0) += r[0][0]; Z(i+0, j+1) += r[0][1]; Z(i+0, j+2) += r[0][2]; Z(i+0, j+3) += r[0][3];
            Z(i+1, j+0) += r[1][0]; Z(i+1, j+1) += r[1][1]; Z(i+1, j+2) += r[1][2]; Z(i+1, j+3) += r[1][3];
            Z(i+2, j+0) += r[2][0]; Z(i+2, j+1) += r[2][1]; Z(i+2, j+2) += r[2][2]; Z(i+2, j+3) += r[2][3];
            Z(i+3, j+0) += r[3][0]; Z(i+3, j+1) += r[3][1]; Z(i+3, j+2) += r[3][2]; Z(i+3, j+3) += r[3][3];
        }
    }
}
#endif

#if 0
/*
 * Disabled gemm variant (compiled out by the surrounding #if 0).
 *
 * Scales z by b, then processes x in bands of eight rows and y in groups of
 * eight rows, with scalar clean-up loops for the rows of x and y left over
 * after the multiples of eight (ri and rj are ni and nj rounded down to a
 * multiple of 8).
 *
 * NOTE(review): as written, every update has the form
 *     z[k + incz*i] += a * x[k + incx*i] * y[k + incy*j]
 * summed over j, with k indexing the same position in x, y and z. That is an
 * elementwise product of x with the column sums of y, not a sum over k of
 * x(i,k)*y(k,j); z is also indexed with k < nk rather than with a column
 * below nj. Confirm the intended operation before re-enabling.
 * The local rk and the parameters trm and trn are not used.
 */
void
func(gemm)(uint trm, uint trn, INT ni, INT nj, INT nk, FLOAT a, FLOAT *x, INT incx, FLOAT *y, INT incy, FLOAT b, FLOAT *z, INT incz)
{
    int i, j, k, ri, rj, rk;
    FLOAT *xrow[8], *yrow[8], reg;

    for (i = 0; i < ni; i++) {
        for (j = 0; j < nj; j++) {
            z[j + incz*i] *= b;
        }
    }

    /* largest multiples of 8 not exceeding ni and nj */
    ri = ni & ~7;
    rj = nj & ~7;
    for (i = 0; i < ri; i += 8) {
        /* start of rows i..i+7 of x */
        xrow[0] = x + incx*(i+0);
        xrow[1] = x + incx*(i+1);
        xrow[2] = x + incx*(i+2);
        xrow[3] = x + incx*(i+3);
        xrow[4] = x + incx*(i+4);
        xrow[5] = x + incx*(i+5);
        xrow[6] = x + incx*(i+6);
        xrow[7] = x + incx*(i+7);
        for (j = 0; j < rj; j += 8) {
            /* start of rows j..j+7 of y */
            yrow[0] = y + incy*(j+0);
            yrow[1] = y + incy*(j+1);
            yrow[2] = y + incy*(j+2);
            yrow[3] = y + incy*(j+3);
            yrow[4] = y + incy*(j+4);
            yrow[5] = y + incy*(j+5);
            yrow[6] = y + incy*(j+6);
            yrow[7] = y + incy*(j+7);
            for (k = 0; k < nk; k++) {
                /* a times the sum of element k of the eight y rows */
                reg = a*(yrow[0][k] + yrow[1][k] + yrow[2][k] + yrow[3][k] + yrow[4][k] + yrow[5][k] + yrow[6][k] + yrow[7][k]);
                z[k + incz*(i+0)] += xrow[0][k]*reg;
                z[k + incz*(i+1)] += xrow[1][k]*reg;
                z[k + incz*(i+2)] += xrow[2][k]*reg;
                z[k + incz*(i+3)] += xrow[3][k]*reg;
                z[k + incz*(i+4)] += xrow[4][k]*reg;
                z[k + incz*(i+5)] += xrow[5][k]*reg;
                z[k + incz*(i+6)] += xrow[6][k]*reg;
                z[k + incz*(i+7)] += xrow[7][k]*reg;
            }
        }
        /* remaining rows of y (j continues from rj), one at a time */
        for (; j < nj; j++) {
            for (k = 0; k < nk; k++) {
                reg = a*y[k+incy*j];
                z[k + incz*(i+0)] += xrow[0][k]*reg;
                z[k + incz*(i+1)] += xrow[1][k]*reg;
                z[k + incz*(i+2)] += xrow[2][k]*reg;
                z[k + incz*(i+3)] += xrow[3][k]*reg;
                z[k + incz*(i+4)] += xrow[4][k]*reg;
                z[k + incz*(i+5)] += xrow[5][k]*reg;
                z[k + incz*(i+6)] += xrow[6][k]*reg;
                z[k + incz*(i+7)] += xrow[7][k]*reg;
            }
        }
    }

    /* remaining rows of x (i continues from ri), one at a time */
    for (; i < ni; i++) {
        for (j = 0; j < rj; j += 8) {
            yrow[0] = y + incy*(j+0);
            yrow[1] = y + incy*(j+1);
            yrow[2] = y + incy*(j+2);
            yrow[3] = y + incy*(j+3);
            yrow[4] = y + incy*(j+4);
            yrow[5] = y + incy*(j+5);
            yrow[6] = y + incy*(j+6);
            yrow[7] = y + incy*(j+7);
            for (k = 0; k < nk; k++) {
                z[k + incz*(i)] += a*x[k + incx*i]*(yrow[0][k] + yrow[1][k] + yrow[2][k] + yrow[3][k] + yrow[4][k] + yrow[5][k] + yrow[6][k] + yrow[7][k]);
            }
        }
        for (; j < nj; j++) {
            for (k = 0; k < nk; k++) {
                z[k + incz*i] += a*x[k + incx*i]*y[k + incy*j];
            }
        }
    }
}
#endif