root / tests / testeur_fgemm.C

Revision 63, 5.5 kB (checked in by pernet, 7 months ago)

Update finite field usage

Line 
1/* -*- mode: C++; tab-width: 8; indent-tabs-mode: t; c-basic-offset: 8 -*- */
2//--------------------------------------------------------------------------
3//                        Test for the  fgemm winograd
4//                 
5//--------------------------------------------------------------------------
6// Clement Pernet
7//-------------------------------------------------------------------------
8
9#define NEWWINO
10
11#include <iostream>
12#include <iomanip>
13using namespace std;
14//#include "fflas-ffpack/modular-int.h"
15//#include "fflas-ffpack/modular-positive.h"
16#include "fflas-ffpack/modular-balanced.h"
17#include "timer.h"
18#include "Matio.h"
19#include "fflas-ffpack/fflas.h"
20#include "givaro/givintprime.h"
21
22
23
24//typedef ModularBalanced<float> Field;
25typedef ModularBalanced<double> Field;
26//typedef Modular<double> Field;
27//typedef Modular<float> Field;
28//typedef Modular<int> Field;
29//typedef GivaroZpz<Std32> Field;
30//typedef GivaroGfq Field;
31
32int main(int argc, char** argv){
33        Timer tim;
34        IntPrimeDom IPD;
35        Field::Element alpha, beta;
36        long p;
37        size_t M, K, N, Wino;
38        bool keepon = true;
39        Integer _p,tmp;
40        cerr<<setprecision(10);
41        size_t TMAX = 100;
42        size_t PRIMESIZE = 23;
43        size_t WINOMAX = 8;
44       
45        if (argc > 1 )
46                TMAX = atoi(argv[1]);
47        if (argc > 2 )
48                PRIMESIZE = atoi(argv[2]);
49        if (argc > 3 )
50                WINOMAX = atoi(argv[3]);
51
52        enum FFLAS::FFLAS_TRANSPOSE ta, tb;
53        size_t lda,ldb;
54        Field::Element * A;
55        Field::Element * B;
56        Field::Element * C, *Cbis, *Cter;
57       
58        while (keepon){
59                srandom(_p);
60                do{
61                        //              max = Integer::random(2);
62                        _p = random();//max % (2<<30);
63                        IPD.prevprime( tmp, (_p% (1<<PRIMESIZE)) );
64                        p =  tmp;
65                       
66                }while( (p <= 2) );
67               
68                Field F( p ); 
69                Field::RandIter RValue( F );
70                //NonzeroRandIter<Field> RnValue( F, RValue );
71               
72               
73                do{
74                        M = (size_t)  random() % TMAX;
75                        K = (size_t)  random() % TMAX;
76                        N = (size_t)  random() % TMAX;
77                        Wino = random() % WINOMAX;
78                } while (!( (K>>Wino > 0) && (M>>Wino > 0) && (N>>Wino > 0) ));
79
80                if (random()%2){
81                        ta = FFLAS::FflasTrans;
82                        lda = M;
83                }
84                else{
85                        ta = FFLAS::FflasNoTrans;
86                        lda = K;
87                }
88                if (random()%2){
89                        tb = FFLAS::FflasTrans;
90                        ldb = K;
91                }
92                else{
93                        tb = FFLAS::FflasNoTrans;
94                        ldb = N;
95                }
96               
97                A = new Field::Element[M*K];
98                B = new Field::Element[K*N];
99                C = new Field::Element[M*N];
100                Cbis = new Field::Element[M*N];
101                Cter = new Field::Element[M*N];
102               
103                for( size_t i = 0; i < M*K; ++i )
104                        RValue.random( *(A+i) );
105                for( size_t i = 0; i < K*N; ++i )
106                        RValue.random( *(B+i) );
107                for( size_t i = 0; i < M*N; ++i )
108                        *(Cter+i) = *(Cbis+i)= RValue.random( *(C+i) );
109               
110                RValue.random( alpha );
111                RValue.random( beta );
112               
113                cout <<"p = "<<(size_t)p<<" M = "<<M
114                     <<" N = "<<N<<" K = "<<K<<" Winolevel = "<<Wino<<" "
115                     <<alpha
116                     <<((ta==FFLAS::FflasNoTrans)?".Ax":".A^Tx")
117                     <<((tb==FFLAS::FflasNoTrans)?"B + ":"B^T + ")
118                     <<beta<<".C"
119                     <<"...."; 
120
121                tim.clear();
122                tim.start();
123                FFLAS::fgemm (F, ta, tb, M, N, K, alpha, A, lda, B, ldb, beta, C, N, Wino);
124                tim.stop();
125//              for (int j = 0; j < n; ++j ){
126//                      FFLAS::fgemv( F, FFLAS::FflasNoTrans, m, k, alpha, A, k, B+j, n, beta, Cbis+j, n);
127//                      for (int i=0; i<m; ++i)
128//                              if ( !F.areEqual( *(Cbis+i*n+j), *(C+i*n+j) ) )
129//                                      keepon = false;
130//              }
131                Field::Element aij, bij, boa, temp;
132                //F.div(boa, beta, alpha);
133                for (int i = 0; i < M; ++i )
134                        for ( int j = 0; j < N; ++j ){
135                                //                              F.mulin(*(Cbis+i*N+j),boa);
136                                F.mulin(*(Cbis+i*N+j),beta);
137                                for ( int l = 0; l < K ; ++l ){
138                                        if ( ta == FFLAS::FflasNoTrans )
139                                                aij = *(A+i*lda+l);
140                                        else
141                                                aij = *(A+l*lda+i);
142                                        if ( tb == FFLAS::FflasNoTrans )
143                                                bij = *(B+l*ldb+j);
144                                        else
145                                                bij = *(B+j*ldb+l);
146                                        F.mul(temp,aij,bij);
147                                        F.axpyin( *(Cbis+i*N+j), alpha, temp);
148                                        //F.axpyin( *(Cbis+i*N+j), aij, bij );
149                                }
150                                //F.mulin( *(Cbis+i*N+j),alpha );
151                                if ( !F.areEqual( *(Cbis+i*N+j), *(C+i*N+j) ) ) {
152                                        cerr<<"error for i,j="<<i<<" "<<j<<" "<<*(C+i*N+j)<<" "<<*(Cbis+i*N+j)<<endl;
153                                        keepon = false;
154                                }
155                        }
156               
157                if (keepon){
158                        cout<<"Passed "
159                            <<(2*M*N/1000.0*K/tim.usertime()/1000.0)<<"Mfops"<<endl; 
160                        delete[] A;
161                        delete[] B;
162                        delete[] C;
163                        delete[] Cbis;
164                        delete[] Cter;
165                }
166                else{
167                        // cerr<<"C="<<endl;
168//                      write_field( F, cerr, C, M, N, N );
169//                      cerr<<"Cbis="<<endl;
170//                      write_field( F, cerr, Cbis, M, N, N );
171                }
172        }
173        cout<<endl;
174        cerr<<"FAILED with p = "<<(size_t)p<<" M = "<<M<<" N = "<<N<<" K = "<<K
175            <<" Winolevel = "<<Wino
176            <<" alpha = "<<(size_t)alpha<<" beta = "<<(size_t)beta<<endl; 
177        cerr<<"A:"<<endl;
178        if ( ta ==FFLAS::FflasNoTrans ){
179                cerr<<M<<" "<<K<<" M"<<endl;
180                for (size_t i=0; i<M; ++i)
181                        for (size_t j=0; j<K; ++j)
182                                cerr<<i+1<<" "<<j+1<<" "<<((size_t) *(A+i*lda+j) )<<endl;
183        }
184        else{
185                cerr<<K<<" "<<M<<" M"<<endl;
186                for (size_t i=0; i<K; ++i)
187                        for (size_t j=0; j<M; ++j)
188                                cerr<<i+1<<" "<<j+1<<" "<<((size_t) *(A+j*lda+i) )<<endl;
189
190        }
191        cerr<<"0 0 0"<<endl<<endl;
192        cerr<<"B:"<<endl;
193        if ( tb ==FFLAS::FflasNoTrans ){
194                cerr<<K<<" "<<N<<" M"<<endl;
195                for (size_t i=0; i<K; ++i)
196                        for (size_t j=0; j<N; ++j)
197                                cerr<<i+1<<" "<<j+1<<" "<<((size_t) *(B+i*ldb+j) )<<endl;
198        }
199        else{
200                cerr<<N<<" "<<K<<" M"<<endl;
201                for (size_t i=0; i<N; ++i)
202                        for (size_t j=0; j<K; ++j)
203                                cerr<<i+1<<" "<<j+1<<" "<<((size_t) *(B+i+j*ldb) )<<endl;
204        }
205        cerr<<"0 0 0"<<endl<<endl;
206        cerr<<"C:"<<endl
207            <<M<<" "<<N<<" M"<<endl;
208        for (size_t i=0; i<M; ++i)
209                for (size_t j=0; j<N; ++j)
210                        cerr<<i+1<<" "<<j+1<<" "<<((size_t) *(Cter+i*N+j) )<<endl;
211        cerr<<"0 0 0"<<endl;
212
213        delete[] A;
214        delete[] B;
215        delete[] C;
216        delete[] Cbis;
217        delete[] Cter;
218}
219
220
221
222
223
224
225
226
227
228
229
230
231
Note: See TracBrowser for help on using the browser.