root / tests / test-fgemm.C

Revision 63, 4.3 kB (checked in by pernet, 7 months ago)

Update finite field usage

Line 
1/* -*- mode: C++; tab-width: 8; indent-tabs-mode: t; c-basic-offset: 8 -*- */
2//--------------------------------------------------------------------------
3//                        Test for fgemm : 1 computation
4//                 
5//--------------------------------------------------------------------------
6// Clement Pernet
7//-------------------------------------------------------------------------
8
9#define DEBUG 1
10#define NEWWINO
11#define TIME 1
12
13#include <iomanip>
14#include <iostream>
15using namespace std;
16
17//#include "fflas-ffpack/modular-positive.h"
18//#include "fflas-ffpack/modular-balanced.h"
19#include "fflas-ffpack/modular-int.h"
20#include "timer.h"
21#include "Matio.h"
22#include "fflas-ffpack/fflas.h"
23
24
25
26
27//typedef Modular<double> Field;
28//typedef Modular<float> Field;
29typedef ModularBalanced<double> Field;
30//typedef ModularBalanced<float> Field;
31//typedef Modular<int> Field;
32
33int main(int argc, char** argv){
34
35        int m,n,k;
36        int nbw=atoi(argv[4]); // number of winograd levels
37        int nbit=atoi(argv[5]); // number of times the product is performed
38        cerr<<setprecision(10);
39        Field::Element alpha,beta;
40
41
42        if (argc != 11) {
43                cerr<<"Usage : test-fgemm <p> <A> <B> <w> <i>"
44                    <<" <alpha> <beta> <C> <ta> <tb>"<<endl
45                    <<"         to do i computations of C <- alpha AB + beta C"
46                    <<" using w recursive levels of Winograd's algorithm"
47                    <<endl
48                    <<"         if ta=1 (resp tb=1), A (resp B) is transposed."
49                    <<endl;
50                exit(-1);
51        }
52        Field F((long)atoi(argv[1]));
53
54        F.init( alpha, Field::Element(atoi(argv[6])));
55        F.init( beta, Field::Element(atoi(argv[7])));
56
57        Field::Element * A;
58        Field::Element * B;
59        size_t lda;
60        size_t ldb;
61       
62        enum FFLAS::FFLAS_TRANSPOSE ta = FFLAS::FflasNoTrans;
63        enum FFLAS::FFLAS_TRANSPOSE tb = FFLAS::FflasNoTrans;
64        if (atoi(argv[9])){
65                ta = FFLAS::FflasTrans;
66                A = read_field(F,argv[2],&k,&m);
67                        lda = m;
68        }
69        else{
70                A = read_field(F,argv[2],&m,&k);
71                lda = k;
72        }
73        if (atoi(argv[10])){
74                tb = FFLAS::FflasTrans;
75                B = read_field(F,argv[3],&n,&k);
76                ldb = k;
77        }
78        else{
79                B = read_field(F,argv[3],&k,&n);
80                ldb = n;
81        }
82       
83        Field::Element * C=NULL;
84
85//      write_field (F, cerr<<"A = "<<endl, A, m, k, lda);
86//      write_field (F, cerr<<"B = "<<endl, B, k, n, ldb);
87        Timer tim,t; t.clear();tim.clear(); 
88        for(int i = 0;i<nbit;++i){
89                if (!F.isZero(beta)){
90                        C = read_field(F,argv[8],&m,&n);
91                }else
92                        C = new Field::Element[m*n];
93                t.clear();
94                t.start();
95                FFLAS::fgemm (F, ta, tb,m,n,k,alpha, A,lda, B,ldb,
96                              beta,C,n,nbw);
97                t.stop();
98                tim+=t;
99                if (i<nbit-1)
100                        delete[] C;
101        }
102
103#if DEBUG
104        bool wrong = false;
105        Field::Element zero;
106        F.init(zero, 0.0);
107        Field::Element * Cd;
108        if (!F.isZero(beta))
109                Cd = read_field(F,argv[8],&m,&n);
110        else{
111                Cd  = new Field::Element[m*n];
112                for (int i=0; i<m*n; ++i)
113                        F.assign (*(Cd+i), zero);
114        }
115        Field::Element aij, bij, beta_alpha, tmp;
116        //F.div (beta_alpha, beta, alpha);
117        for (int i = 0; i < m; ++i)
118                for (int j = 0; j < n; ++j){
119                        F.mulin(*(Cd+i*n+j),beta);
120                        F.assign (tmp, zero);
121                        for ( int l = 0; l < k ; ++l ){
122                                if ( ta == FFLAS::FflasNoTrans )
123                                        aij = *(A+i*lda+l);
124                                else
125                                        aij = *(A+l*lda+i);
126                                if ( tb == FFLAS::FflasNoTrans )
127                                        bij = *(B+l*ldb+j);
128                                else
129                                        bij = *(B+j*ldb+l);
130                                //F.mul (tmp, aij, bij);
131                                //F.axpyin( *(Cd+i*n+j), alpha, tmp );
132                                F.axpyin (tmp, aij, bij); 
133                        }
134                        F.axpyin (*(Cd+i*n+j), alpha, tmp);
135                        //F.mulin( *(Cd+i*n+j),alpha );
136                        if ( !F.areEqual( *(Cd+i*n+j), *(C+i*n+j) ) ) {
137                                wrong = true;
138                        }
139                }
140        if ( wrong ){
141                cerr<<"FAIL"<<endl;
142                for (int i=0; i<m; ++i){
143                        for (int j =0; j<n; ++j)
144                                if (!F.areEqual( *(C+i*n+j), *(Cd+i*n+j) ) )
145                                         cerr<<"Erreur C["<<i<<","<<j<<"]="
146                                             <<(*(C+i*n+j))<<" C[d"<<i<<","<<j<<"]="
147                                             <<(*(Cd+i*n+j))<<endl;
148                }
149        }
150        else{
151                cerr<<"PASS"<<endl;
152        }
153        delete[] Cd;
154#endif
155
156        delete[] C;
157        delete[] A;
158        delete[] B;
159#if TIME
160        double mflops = (2.0*(m*k-((!F.isZero(beta))?m:0))/1000000.0)*nbit*n/tim.usertime();
161        cerr << nbw << " Winograd's level over Z/"<<atoi(argv[1])<<"Z : t= "
162             << tim.usertime()/nbit 
163             << " s, Mffops = "<<mflops
164             << endl;
165       
166        cerr<<"m,n,k,nbw = "<<m<<", "<<n<<", "<<k<<", "<<alpha
167            <<", "<<beta<<", "<<nbw<<endl
168            <<alpha
169            <<((ta==FFLAS::FflasNoTrans)?".Ax":".A^Tx")
170            <<((tb==FFLAS::FflasNoTrans)?"B + ":"B^T + ")
171            <<beta<<".C"<<endl;
172        cout<<m<<" "<<n<<" "<<k<<" "<<nbw<<" "<<alpha<<" "<<beta<<" "
173            <<mflops<<" "<<tim.usertime()/nbit<<endl;
174#endif
175} 
176
Note: See TracBrowser for help on using the browser.