root / tests / testeur_ftrsm.C

Revision 38, 4.9 kB (checked in by pernet, 1 year ago)

New implementation for ftrsm and ftrmm:
based on the multicascade algorithm (cf C Pernet pdf thesis), reducing the number of modular reduction for the updates.
Automatic generation of the code for each of the 48 variations.

Line 
1/* -*- mode: C++; tab-width: 8; indent-tabs-mode: t; c-basic-offset: 8 -*- */
2//--------------------------------------------------------------------------
3//                        Sanity check for ftrsm and ftrmm
4//                 
5//--------------------------------------------------------------------------
6// Clement Pernet 2007
7//-------------------------------------------------------------------------
8
9
10#include <iomanip>
11#include <iostream>
12#include "fflas-ffpack/modular-balanced.h"
13//#include "fflas-ffpack/modular-int.h"
14#include "timer.h"
15#include "Matio.h"
16#include "fflas-ffpack/fflas.h"
17#include "givaro/givintprime.h"
18
19using namespace std;
20 
21//typedef Modular<int> Field;
22//typedef Modular<float> Field;
23typedef Modular<double> Field;
24
25int main(int argc, char** argv){
26
27
28        Timer tim;
29        IntPrimeDom IPD;
30        unsigned long p;
31        size_t M, N, K ;
32        bool keepon = true;
33        Integer _p,tmp;
34        Field::Element zero,one;
35        cerr<<setprecision(10);
36
37        size_t TMAX = 300;
38        size_t PRIMESIZE = 23;
39        if (argc > 1 )
40                TMAX = atoi(argv[1]);
41        if (argc > 2 )
42                PRIMESIZE = atoi(argv[2]);
43
44        FFLAS::FFLAS_TRANSPOSE trans;
45        FFLAS::FFLAS_SIDE side;
46        FFLAS::FFLAS_UPLO uplo;
47        FFLAS::FFLAS_DIAG diag;
48        size_t lda, ldb;
49
50        Field::Element * A, *Abis, *B,* Bbis;
51        Field::Element alpha;
52
53        while (keepon){
54                srandom(_p);
55                do{
56                        //              max = Integer::random(2);
57                        _p = random();//max % (2<<30);
58                        IPD.prevprime( tmp, (_p% (1<<PRIMESIZE)) );
59                        p =  tmp;
60                }while( (p <= 2) );
61               
62                Field F (p); 
63                F.init (zero,0.0);
64                F.init (one,1.0);
65                Field::RandIter RValue (F);
66               
67                do{
68                        M = (size_t)  random() % TMAX;
69                        N = (size_t)  random() % TMAX;
70                } while ((M == 0) || (N == 0));
71
72                ldb = N;
73
74                //if (random()%2)
75                if (1)
76                        trans = FFLAS::FflasNoTrans;
77                else
78                        trans = FFLAS::FflasTrans;
79               
80
81                if (random()%2)
82                        diag = FFLAS::FflasUnit;
83                else
84                        diag = FFLAS::FflasNonUnit;
85
86                if (random()%2){
87                        side = FFLAS::FflasLeft;
88                        K = M;
89                        lda = M;
90                } else {
91                        side = FFLAS::FflasRight;
92                        K = N;
93                        lda = N;
94                }
95
96                if (random()%2)
97                        uplo = FFLAS::FflasUpper;
98                else 
99                        uplo = FFLAS::FflasLower;
100               
101                while (F.isZero(RValue.random (alpha)));
102               
103                A = new Field::Element[K*K];
104                B = new Field::Element[M*N];
105                Abis = new Field::Element[K*K];
106                Bbis = new Field::Element[M*N];
107                for (size_t i = 0; i < M; ++i)
108                        for (size_t j = 0; j < N; ++j){
109                                RValue.random (*(B + i*N + j));
110                                *(Bbis + i*N + j) = *(B + i*N + j);
111                        }
112                for (size_t i = 0; i < K; ++i)
113                        for (size_t j = 0; j < K; ++j)
114                                *(Abis + i*K + j) = RValue.random (*(A + i*K + j));
115                for (size_t i = 0; i < K; ++i){
116                        while (F.isZero(RValue.random (*(A + i*(K+1)))));
117                        *(Abis + i*(K +1)) = *(A + i*(K+1));
118                }
119
120                cout <<"p = "<<(size_t)p
121                     <<" M = "<<M
122                     <<" N = "<<N
123                     <<((side==FFLAS::FflasLeft)?" Left ":" Right ")
124                     <<((uplo==FFLAS::FflasLower)?" Lower ":" Upper ")
125                     <<((trans==FFLAS::FflasTrans)?" Trans ":" NoTrans ")
126                     <<((diag==FFLAS::FflasUnit)?" Unit ":" NonUnit ")
127                     <<"...."; 
128
129                       
130                tim.clear();
131                tim.start();
132                FFLAS::ftrsm (F, side, uplo, trans, diag, M, N, alpha,
133                              A, lda, B, ldb);
134                tim.stop();
135
136                // Verification
137                Field::Element invalpha;
138                F.inv(invalpha, alpha);
139                FFLAS::ftrmm (F, side, uplo, trans, diag, M, N, invalpha,
140                              A, K, B, N); 
141                for (size_t i = 0;i < M;++i)
142                        for (size_t j = 0;j < N; ++j)
143                                if ( !F.areEqual (*(Bbis + i*N+ j ), *(B + i*N + j))){
144                                        cerr<<endl
145                                            <<"Bbis ["<<i<<", "<<j<<"] = "<<(*(Bbis + i*N + j))
146                                            <<" ; B ["<<i<<", "<<j<<"] = "<<(*(B + i*N + j));
147                                           
148                                        keepon = false;
149                                }
150                for (size_t i = 0;i < K; ++i)
151                        for (size_t j = 0;j < K; ++j)
152                                if ( !F.areEqual (*(A + i*K + j), *(Abis + i*K + j))){
153                                        cerr<<endl
154                                            <<"A ["<<i<<", "<<j<<"] = "<<(*(A + i*K + j))
155                                            <<" ; Abis ["<<i<<", "<<j<<"] = "<<(*(Abis + i*K + j));
156                                        keepon = false;
157                                }
158                if (keepon) {
159                        cout<<" Passed "
160                            <<M*N/1000000.0*K/tim.usertime()<<" Mfops"<<endl; 
161                       
162                        delete[] B;
163                        delete[] Bbis;
164                        delete[] A;
165                        delete[] Abis;
166                } else {
167                       
168                        cerr<<endl;
169                        write_field (F, cerr<<"A = "<<endl, Abis, K,K,K);
170                        write_field (F, cerr<<"B = "<<endl, Bbis, M,N,N);
171                }
172        }
173       
174        cout<<endl;
175        cerr<<"FAILED with p = "<<(size_t)p
176            <<" M = "<<M
177            <<" N = "<<N
178            <<" alpha = "<<alpha
179            <<((side==FFLAS::FflasLeft)?" Left ":" Right ")
180            <<((uplo==FFLAS::FflasLower)?" Lower ":" Upper ")
181            <<((trans==FFLAS::FflasTrans)?" Trans ":" NoTrans ")
182            <<((diag==FFLAS::FflasUnit)?" Unit ":" NonUnit ")
183            <<endl;
184       
185        cerr<<"A:"<<endl;
186        cerr<<K<<" "<<K<<" M"<<endl;
187        for (size_t i=0; i<K; ++i)
188                for (size_t j=0; j<K; ++j)
189                        if ((*(Abis + i*lda + j)))
190                                cerr<<i+1<<" "<<j+1<<" "
191                                    <<((int) *(Abis+i*lda+j) )
192                                    <<endl;
193        cerr<<"0 0 0"<<endl<<endl;
194
195        cerr<<"B:"<<endl;
196        cerr<<M<<" "<<N<<" M"<<endl;
197        for (size_t i=0; i<M; ++i)
198                for (size_t j=0; j<N; ++j)
199                        if ((*(Bbis + i*ldb + j)))
200                                cerr<<i+1<<" "<<j+1<<" "
201                                    <<((int) *(Bbis+i*ldb+j) )
202                                    <<endl;
203        cerr<<"0 0 0"<<endl<<endl;
204
205        delete[] A;
206        delete[] Abis;
207        delete[] B;
208        delete[] Bbis;
209}
Note: See TracBrowser for help on using the browser.