aboutsummaryrefslogtreecommitdiff
path: root/src/libmath/blas.c
blob: 18f97607e2f9d52685e7ad1d5a67c54d2ce9e742 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
#include <u.h>
#include <base.h>
#include <libmath.h>
#include <libmath/blas.h>
#include <time.h>

/* #include <vendor/blas/cblas.h> */

/*
 * Benchmark dimensions and settings.
 *
 * The expression-valued macros are parenthesized so they expand as a single
 * operand in any context (e.g. `n / NCOL`, `n % NROW`); without the
 * parentheses `n / 2*512` would parse as `(n / 2) * 512`.
 */
#define NCOL (2*512)    /* inner-loop extent of the fill; passed to dgemm as a dimension */
#define NROW (2*512)    /* outer-loop extent of the fill; also passed to dgemm as the leading dimensions */
#define NSUM (2*512)    /* third dimension argument passed to dgemm */
#define NIT  10         /* number of benchmark iterations */
#define INC  1          /* increment argument passed to the dasum calls */
/*
 * Benchmark driver comparing blas·dgemm against cblas_dgemm.
 *
 * For each of NIT iterations the input buffers x, y and z are refilled from
 * rng·random(), w is set to a copy of z, and both dgemm implementations are
 * run on the same inputs (writing into z and w respectively). The order of
 * the two calls alternates with the parity of the iteration index. The
 * clock() time of each call is accumulated in milliseconds, and the dasum of
 * each output buffer is accumulated as a checksum. The per-implementation
 * mean time and the accumulated checksum are printed at the end.
 *
 * Returns 0 on success, 1 if a buffer could not be allocated.
 * NOTE(review): assumes `error` is an integer-like type where non-zero
 * signals failure -- confirm against its declaration.
 * NOTE(review): the cblas_* names and Cblas* constants must be declared by
 * one of the included headers; the direct cblas.h include in this file is
 * commented out -- confirm.
 */
error
main()
{
    int i, j, nit;
    double *x, *y, *z, *w;
    /* Zero-initialized: both entries are only ever updated with += below. */
    double res[2] = { 0 };

    clock_t t;
    double tprof[2] = { 0 };

    rng·init(0);

    x = malloc(sizeof(*x)*NROW*NCOL);
    y = malloc(sizeof(*y)*NROW*NCOL);
    z = malloc(sizeof(*z)*NROW*NCOL);
    w = malloc(sizeof(*w)*NROW*NCOL);
    if (!x || !y || !z || !w) {
        fprintf(stderr, "blas benchmark: out of memory\n");
        /* free(NULL) is a no-op, so release whichever allocations succeeded. */
        free(x);
        free(y);
        free(z);
        free(w);
        return 1;
    }

/* Time one call of this library's dgemm on (x, y, z); fold in dasum(z). */
#define DO_0  do {                                                          \
                  t = clock();                                              \
                  blas·dgemm(0,0,NROW,NCOL,NSUM,10.1,x,NROW,y,NROW,1.2,z,NROW);\
                  t = clock() - t;                                          \
                  res[0] += blas·dasum(NROW*NCOL,z,INC);                    \
                  tprof[0] += 1000.*t/CLOCKS_PER_SEC;                       \
              } while (0)

/* Time one call of cblas_dgemm on (x, y, w); fold in dasum(w). */
#define DO_1  do {                                                          \
                  t = clock();                                              \
                  cblas_dgemm(CblasRowMajor,CblasNoTrans,CblasNoTrans,NROW,NCOL,NSUM,10.1,x,NROW,y,NROW,1.2,w,NROW);\
                  t = clock() - t;                                          \
                  res[1] += cblas_dasum(NROW*NCOL,w,INC);                   \
                  tprof[1] += 1000.*t/CLOCKS_PER_SEC;                       \
              } while (0)

    for (nit = 0; nit < NIT; nit++) {
        /*
         * Fresh inputs every iteration; w mirrors z so both implementations
         * start from the same output buffer contents.
         * NOTE(review): the row stride used here is NROW while the inner
         * loop runs to NCOL; this is only in bounds because NROW == NCOL.
         */
        for (i = 0; i < NROW; i++) {
            for (j = 0; j < NCOL; j++) {
                x[j + NROW*i] = rng·random();
                y[j + NROW*i] = rng·random();
                z[j + NROW*i] = rng·random();
                w[j + NROW*i] = z[j + NROW*i];
            }
        }

        /* Alternate which implementation runs first on each iteration. */
        if (nit % 2 == 0) {
            DO_0;
            DO_1;
        } else {
            DO_1;
            DO_0;
        }
    }
    printf("mean time/iteration (mine): %fms\n", tprof[0]/NIT);
    printf("--> result (mine): %f\n", res[0]);
    printf("mean time/iteration (openblas): %fms\n", tprof[1]/NIT);
    printf("--> result (openblas): %f\n", res[1]);

    free(x);
    free(y);
    free(z);
    free(w);

    return 0;
}