Show
Ignore:
Timestamp:
09/26/07 17:13:25 (1 year ago)
Author:
pernet
Message:

Make test-fgesv work. No bugs known.

Files:
1 modified

Legend:

Unmodified
Added
Removed
  • include/fflas-ffpack/ffpack.h

    r48 r51  
    177177         */ 
    178178        template <class Field> 
    179         static typename Field::Element * 
     179        static void 
    180180        fgetrs (const Field& F, 
    181181                const FFLAS_SIDE Side, 
    182182                const size_t M, const size_t N, const size_t R, 
    183                 const Field::Element *A, const size_t lda, 
     183                typename Field::Element *A, const size_t lda, 
    184184                const size_t *P, const size_t *Q, 
    185                 Field::Element *B, const size_t ldb, 
     185                typename Field::Element *B, const size_t ldb, 
    186186                int * info){ 
    187187 
    188                 static zero; 
     188                static typename Field::Element zero, one, mone; 
    189189                F.init (zero, 0.0); 
     190                F.init (one, 1.0); 
     191                F.neg(mone, one); 
     192                 
    190193                if (Side == FflasLeft) { // Left looking solve A X = B 
    191194 
    192                         SolveLB (F, FflasLeft, M, N, R, A, lda, Q, B, ldb); 
     195                        solveLB2 (F, FflasLeft, M, N, R, A, lda, Q, B, ldb); 
    193196 
    194197                        applyP (F, FflasLeft, FflasNoTrans, N, 0, R, B, ldb, Q); 
     
    212215                         
    213216                        applyP (F, FflasLeft, FflasTrans, N, 0, R, B, ldb, P); 
    214  
    215                         return B; 
    216217                         
    217218                } else { // Right Looking X A = B 
     
    220221                         
    221222                        ftrsm (F, FflasRight, FflasUpper, FflasNoTrans, FflasNonUnit,  
    222                                N, R, one, A, lda , B, ldb); 
     223                               M, R, one, A, lda , B, ldb); 
    223224 
    224225                        fgemm (F, FflasNoTrans, FflasNoTrans, M, N-R, R, one, 
     
    238239                        applyP (F, FflasRight, FflasNoTrans, M, 0, R, B, ldb, Q); 
    239240 
    240                         SolveLB (F, FflasRight, M, N, R, A, lda, Q, B, ldb); 
     241                        solveLB2 (F, FflasRight, M, N, R, A, lda, Q, B, ldb); 
    241242                } 
    242243        } 
     
    262263         * @param B Right/Left hand side matrix.  
    263264         * @param ldb leading dimension of B 
    264          * @info Succes of the computation: 0 if successfull, >0 if system is inconsistent 
     265         * @param info Succes of the computation: 0 if successfull, >0 if system is inconsistent 
    265266         */ 
    266267        template <class Field> 
     
    269270                const FFLAS_SIDE Side, 
    270271                const size_t M, const size_t N, const size_t NRHS, const size_t R, 
    271                 const Field::Element *A, const size_t lda, 
     272                typename Field::Element *A, const size_t lda, 
    272273                const size_t *P, const size_t *Q, 
    273                 Field::Element *X, const size_t ldb, 
    274                 const Field::Element *B, const size_t ldb, 
     274                typename Field::Element *X, const size_t ldx, 
     275                const typename Field::Element *B, const size_t ldb, 
    275276                int * info) { 
    276277 
    277                 static zero; 
     278                static typename Field::Element zero, one, mone; 
    278279                F.init (zero, 0.0); 
     280                F.init (one, 1.0); 
     281                F.neg(mone, one); 
     282 
     283                typename Field::Element* W; 
     284                size_t ldw; 
     285 
    279286                if (Side == FflasLeft) { // Left looking solve A X = B 
    280  
    281                         for (size_t i=0; i < N; ++i) 
    282                                 fcopy (F, NRHS, X + i*ldx, 1, B + i*ldb, 1); 
     287                         
     288                        // Initializing X to 0 (to be optimized) 
     289                        for (size_t i = 0; i <N; ++i) 
     290                                for (size_t j=0; j< NRHS; ++j) 
     291                                        F.assign (*(X+i*ldx+j), zero); 
     292 
     293                        if (M > N){ // Cannot copy B into X 
     294                                W = new typename Field::Element [M*NRHS]; 
     295                                ldw = NRHS; 
     296                                for (size_t i=0; i < M; ++i) 
     297                                        fcopy (F, NRHS, W + i*ldw, 1, B + i*ldb, 1); 
    283298                                
    284                         SolveLB (F, FflasLeft, M, N, R, A, lda, Q, B, ldb); 
    285  
    286                         applyP (F, FflasLeft, FflasNoTrans, N, 0, R, B, ldb, Q); 
    287  
    288                         bool consistent = true; 
    289                         for (size_t i = R; i < M; ++i) 
    290                                 for (size_t j = 0; j < N; ++j) 
    291                                         if (!F.isZero (*(B + i*ldb + j))) 
    292                                                 consistent = false; 
    293                         if (!consistent) { 
    294                                 std::cerr<<"System is inconsistent"<<std::endl; 
    295                                 *info = 1; 
    296                         } 
    297                         // The last rows of B are now supposed to be 0 
    298                         //                      for (size_t i = R; i < M; ++i) 
    299                         //                              for (size_t j = 0; j < N; ++j) 
    300                         //                                      *(B + i*ldb + j) = zero; 
    301  
    302                         ftrsm (F, FflasLeft, FflasUpper, FflasNoTrans, FflasNonUnit,  
    303                                R, N, one, A, lda , B, ldb); 
    304                          
    305                         applyP (F, FflasLeft, FflasTrans, N, 0, R, B, ldb, P); 
    306  
    307                         return B; 
     299                                solveLB2 (F, FflasLeft, M, NRHS, R, A, lda, Q, W, ldw); 
     300                                 
     301                                applyP (F, FflasLeft, FflasNoTrans, N, 0, R, W, ldw, Q); 
     302 
     303                                bool consistent = true; 
     304                                for (size_t i = R; i < M; ++i) 
     305                                        for (size_t j = 0; j < NRHS; ++j) 
     306                                                if (!F.isZero (*(W + i*ldw + j))) 
     307                                                        consistent = false; 
     308                                if (!consistent) { 
     309                                        std::cerr<<"System is inconsistent"<<std::endl; 
     310                                        *info = 1; 
     311                                        return X; 
     312                                } 
     313                                // Here the last rows of W are supposed to be 0 
     314                                 
     315                                ftrsm (F, FflasLeft, FflasUpper, FflasNoTrans, FflasNonUnit,  
     316                                       R, NRHS, one, A, lda , W, ldw); 
     317                         
     318                                for (size_t i=0; i < R; ++i) 
     319                                        fcopy (F, NRHS, X + i*ldx, 1, W + i*ldw, 1); 
     320 
     321                                delete[] W; 
     322                                applyP (F, FflasLeft, FflasTrans, NRHS, 0, R, X, ldx, P); 
     323                                 
     324                        } else { // Copy B to X directly 
     325                                for (size_t i=0; i < M; ++i) 
     326                                        fcopy (F, NRHS, X + i*ldx, 1, B + i*ldb, 1); 
     327                                
     328                                solveLB2 (F, FflasLeft, M, NRHS, R, A, lda, Q, X, ldx); 
     329                                 
     330                                applyP (F, FflasLeft, FflasNoTrans, N, 0, R, X, ldx, Q); 
     331 
     332                                bool consistent = true; 
     333                                for (size_t i = R; i < M; ++i) 
     334                                        for (size_t j = 0; j < NRHS; ++j) 
     335                                                if (!F.isZero (*(X + i*ldx + j))) 
     336                                                        consistent = false; 
     337                                if (!consistent) { 
     338                                        std::cerr<<"System is inconsistent"<<std::endl; 
     339                                        *info = 1; 
     340                                        return X; 
     341                                } 
     342                                // Here the last rows of W are supposed to be 0 
     343                                                                 
     344                                ftrsm (F, FflasLeft, FflasUpper, FflasNoTrans, FflasNonUnit,  
     345                                       R, NRHS, one, A, lda , X, ldx); 
     346                         
     347                                applyP (F, FflasLeft, FflasTrans, NRHS, 0, R, X, ldx, P); 
     348                        } 
     349 
     350                        return X; 
    308351                         
    309352                } else { // Right Looking X A = B 
    310353 
    311                         for (size_t i=0; i < NRHS; ++i) 
    312                                 fcopy (F, M, X + i*ldx, 1, B + i*ldb, 1); 
    313  
    314                         applyP (F, FflasRight, FflasTrans, M, 0, R, B, ldb, P); 
    315                          
    316                         ftrsm (F, FflasRight, FflasUpper, FflasNoTrans, FflasNonUnit,  
    317                                N, R, one, A, lda , B, ldb); 
    318  
    319                         fgemm (F, FflasNoTrans, FflasNoTrans, M, N-R, R, one, 
    320                                B, ldb, A+R, lda, mone, B+R, ldb); 
    321  
    322                         bool consistent = true; 
    323                         for (size_t i = 0; i < M; ++i) 
    324                                 for (size_t j = R; j < N; ++j) 
    325                                         if (!F.isZero (*(B + i*ldb + j))) 
    326                                                 consistent = false; 
    327                         if (!consistent) { 
    328                                 std::cerr<<"System is inconsistent"<<std::endl; 
    329                                 *info = 1; 
    330                         } 
    331                         // The last cols of B are now supposed to be 0 
    332  
    333                         applyP (F, FflasRight, FflasNoTrans, M, 0, R, B, ldb, Q); 
    334  
    335                         SolveLB (F, FflasRight, M, N, R, A, lda, Q, B, ldb); 
     354                        for (size_t i = 0; i <NRHS; ++i) 
     355                                for (size_t j=0; j< M; ++j) 
     356                                        F.assign (*(X+i*ldx+j), zero); 
     357 
     358                        if (M < N) { 
     359                                W = new typename Field::Element [NRHS*N]; 
     360                                ldw = N; 
     361                                for (size_t i=0; i < NRHS; ++i) 
     362                                        fcopy (F, N, W + i*ldw, 1, B + i*ldb, 1); 
     363 
     364                                applyP (F, FflasRight, FflasTrans, NRHS, 0, R, W, ldw, P); 
     365                         
     366                                ftrsm (F, FflasRight, FflasUpper, FflasNoTrans, FflasNonUnit,  
     367                                       NRHS, R, one, A, lda , W, ldw); 
     368                                 
     369                                fgemm (F, FflasNoTrans, FflasNoTrans, NRHS, N-R, R, one, 
     370                                       W, ldw, A+R, lda, mone, W+R, ldw); 
     371 
     372                                bool consistent = true; 
     373                                for (size_t i = 0; i < NRHS; ++i) 
     374                                        for (size_t j = R; j < N; ++j) 
     375                                                if (!F.isZero (*(W + i*ldw + j))) 
     376                                                        consistent = false; 
     377                                if (!consistent) { 
     378                                        std::cerr<<"System is inconsistent"<<std::endl; 
     379                                        *info = 1; 
     380                                        return X; 
     381                                } 
     382                                // The last N-R cols of W are now supposed to be 0 
     383                                for (size_t i=0; i < NRHS; ++i) 
     384                                        fcopy (F, R, X + i*ldx, 1, W + i*ldb, 1); 
     385 
     386                                applyP (F, FflasRight, FflasNoTrans, NRHS, 0, R, X, ldx, Q); 
     387 
     388                                solveLB2 (F, FflasRight, NRHS, M, R, A, lda, Q, X, ldx); 
     389                                 
     390                        } else { 
     391                                for (size_t i=0; i < NRHS; ++i) 
     392                                        fcopy (F, N, X + i*ldx, 1, B + i*ldb, 1); 
     393                                 
     394                                applyP (F, FflasRight, FflasTrans, NRHS, 0, R, X, ldx, P); 
     395                         
     396                                ftrsm (F, FflasRight, FflasUpper, FflasNoTrans, FflasNonUnit,  
     397                                       NRHS, R, one, A, lda , X, ldx); 
     398                                 
     399                                fgemm (F, FflasNoTrans, FflasNoTrans, NRHS, N-R, R, one, 
     400                                       X, ldx, A+R, lda, mone, X+R, ldx); 
     401 
     402                                bool consistent = true; 
     403                                for (size_t i = 0; i < NRHS; ++i) 
     404                                        for (size_t j = R; j < N; ++j) 
     405                                                if (!F.isZero (*(X + i*ldx + j))) 
     406                                                        consistent = false; 
     407                                if (!consistent) { 
     408                                        std::cerr<<"System is inconsistent"<<std::endl; 
     409                                        *info = 1; 
     410                                        return X; 
     411                                } 
     412                                // The last N-R cols of W are now supposed to be 0 
     413 
     414                                applyP (F, FflasRight, FflasNoTrans, NRHS, 0, R, X, ldx, Q); 
     415                                 
     416                                solveLB2 (F, FflasRight, NRHS, M, R, A, lda, Q, X, ldx); 
     417                                 
     418                        } 
     419                        return X; 
    336420                } 
    337421        } 
    338  
    339422        /** 
    340423         * @brief Square system solver 
     
    349432         * @param B Right/Left hand side matrix. Initially contains B, finally contains the solution X. 
    350433         * @param ldb leading dimension of B 
    351          * @info Succes of the computation: 0 if successfull, >0 if system is inconsistent 
    352          * @return a pointer to B 
     434         * @param info Success of the computation: 0 if successfull, >0 if system is inconsistent 
     435         * @return the rank of the system 
    353436         *  
    354437         * Solve the system A X = B or X A = B. 
     
    358441         */ 
    359442        template <class Field> 
    360         static typename Field::Element * 
    361         fgesv (const FFLAS_SIDE Side, 
     443        static size_t  
     444        fgesv (const Field& F, 
     445               const FFLAS_SIDE Side, 
    362446               const size_t M, const size_t N, 
    363                const Field::Element *A, const size_t lda, 
    364                Field::Element *B, const size_t ldb, 
     447               typename Field::Element *A, const size_t lda, 
     448               typename Field::Element *B, const size_t ldb, 
    365449               int * info){ 
    366450 
     
    371455                        Na = N; 
    372456                 
    373                 size_t P = new size_t[Na]; 
    374                 size_t Q = new size_t[Na]; 
     457                size_t* P = new size_t[Na]; 
     458                size_t* Q = new size_t[Na]; 
    375459 
    376460                size_t R = LUdivine (F, FflasNonUnit, FflasNoTrans, Na, Na, A, lda, P, Q, FfpackLQUP); 
    377461 
    378                 fgetrs (F, Side, M, N , R, A, lda, P, Q, B, ldb, info); 
     462                fgetrs (F, Side, M, N, R, A, lda, P, Q, B, ldb, info); 
    379463                 
    380464                delete[] P; 
    381465                delete[] Q; 
    382466 
    383                 return B; 
     467                return R; 
     468        } 
     469         
     470        /** 
     471         * @brief Rectangular system solver 
     472         * @param Field The computation domain 
     473         * @param Side Determine wheter the resolution is left or right looking 
     474         * @param M row dimension of A 
     475         * @param N col dimension of A 
     476         * @param NRHS number of columns (if Side = FflasLeft) or row (if Side = FflasRight) of the matrices X and B 
     477         * @param A input matrix 
     478         * @param lda leading dimension of A 
     479         * @param P column permutation of the LQUP decomposition of A 
     480         * @param Q column permutation of the LQUP decomposition of A 
     481         * @param B Right/Left hand side matrix. Initially contains B, finally contains the solution X. 
     482         * @param ldb leading dimension of B 
     483         * @info Success of the computation: 0 if successfull, >0 if system is inconsistent 
     484         * @return the rank of the system 
     485         *  
     486         * Solve the system A X = B or X A = B. 
     487         * Version for A square. 
     488         * If A is rank deficient, a solution is returned if the system is consistent, 
     489         * Otherwise an info is 1 
     490         */ 
     491        template <class Field> 
     492        static size_t  
     493        fgesv (const Field& F, 
     494               const FFLAS_SIDE Side, 
     495               const size_t M, const size_t N, const size_t NRHS, 
     496               typename Field::Element *A, const size_t lda, 
     497               typename Field::Element *X, const size_t ldx, 
     498               const typename Field::Element *B, const size_t ldb, 
     499               int * info){ 
     500 
     501                size_t Nb,Mb; 
     502                if (Side == FflasLeft){Nb = NRHS; Mb = N;} 
     503                else {Nb = M; Mb = NRHS;} 
     504                 
     505                size_t* P = new size_t[N]; 
     506                size_t* Q = new size_t[M]; 
     507 
     508                size_t R = LUdivine (F, FflasNonUnit, FflasNoTrans, M, N, A, lda, P, Q, FfpackLQUP); 
     509 
     510                fgetrs (F, Side, M, N, NRHS, R, A, lda, P, Q, X, ldx, B, ldb, info); 
     511                 
     512                delete[] P; 
     513                delete[] Q; 
     514 
     515                return R; 
    384516        } 
    385517         
     
    10361168                F.init( one, 1.0 ); 
    10371169                typename Field::Element * Lcurr,* Rcurr,* Bcurr; 
    1038                 size_t ib, k, Ldim; 
    1039                 //cerr<<"In solveLB"<<endl; 
     1170                size_t ib, Ldim; 
     1171                int k; 
    10401172                if ( Side == FflasLeft ){ 
    10411173                        size_t j = 0; 
    10421174                        while ( j<R ) { 
    10431175                                k = ib = Q[j]; 
    1044                                 //cerr<<"j avant="<<j<<endl; 
    1045                                 while ( (Q[j] == k) && (j<R) ) {k++;j++;} 
     1176                                while ( (Q[j] == (size_t)k) && (j<R) ) {k++;j++;} 
    10461177                                Ldim = k-ib; 
    1047                                 //cerr<<"k, ib, j, R "<<k<<" "<<ib<<" "<<j<<" "<<R<<endl; 
    1048                                 //cerr<<"M,k="<<M<<" "<<k<<endl; 
    1049                                 //cerr<<" ftrsm with M, N="<<Ldim<<" "<<N<<endl; 
    10501178                                Lcurr = L + j-Ldim + ib*ldl; 
    10511179                                Bcurr = B + ib*ldb; 
    10521180                                Rcurr = Lcurr + Ldim*ldl; 
    1053                                 ftrsm( F, Side, FflasLower, FflasNoTrans, FflasUnit, Ldim, N, one, Lcurr, ldl , Bcurr, ldb ); 
    1054                                 //cerr<<"M,k="<<M<<" "<<k<<endl; 
    1055                                 //cerr<<" fgemm with M, N, K="<<M-k<<" "<<N<<" "<<Ldim<<endl; 
    1056                                 fgemm( F, FflasNoTrans, FflasNoTrans, M-k, N, Ldim, Mone, Rcurr , ldl, Bcurr, ldb, one, Bcurr+Ldim*ldb, ldb); 
     1181 
     1182                                ftrsm( F, Side, FflasLower, FflasNoTrans, FflasUnit, Ldim, N, one, 
     1183                                       Lcurr, ldl , Bcurr, ldb ); 
     1184 
     1185                                fgemm( F, FflasNoTrans, FflasNoTrans, M-k, N, Ldim, Mone, 
     1186                                       Rcurr , ldl, Bcurr, ldb, one, Bcurr+Ldim*ldb, ldb); 
    10571187                        } 
    10581188                } 
    10591189                else{ // Side == FflasRight 
    10601190                        int j=R-1; 
    1061                         while ( j >=0 ) { 
    1062                                 //cerr<<"j="<<j<<endl; 
     1191                        while ( j >= 0 ) { 
    10631192                                k = ib = Q[j]; 
    1064                                 while ( (j>=0) &&  (Q[j] == k)  ) {--k;--j;} 
     1193                                while ( (j >= 0) &&  ( Q[j] == k)  ) {--k;--j;} 
    10651194                                Ldim = ib-k; 
    1066                                 //cerr<<"Ldim, ib, k, N = "<<Ldim<<" "<<ib<<" "<<k<<" "<<N<<endl; 
    10671195                                Lcurr = L + j+1 + (k+1)*ldl; 
    1068                                 Bcurr = B + ib; 
     1196                                Bcurr = B + ib+1; 
    10691197                                Rcurr = Lcurr + Ldim*ldl; 
    1070                                 fgemm (F, FflasNoTrans, FflasNoTrans, M,  Ldim, N-ib-1, Mone, Bcurr, ldb, Rcurr, ldl,  one, Bcurr-Ldim, ldb); 
    1071                                 //cerr<<"j avant="<<j<<endl; 
    1072                                 //cerr<<"k, ib, j, R "<<k<<" "<<ib<<" "<<j<<" "<<R<<endl; 
    1073                                 //cerr<<"M,k="<<M<<" "<<k<<endl; 
    1074                                 //cerr<<" ftrsm with M, N="<<Ldim<<" "<<N<<endl; 
    1075                                 ftrsm (F, Side, FflasLower, FflasNoTrans, FflasUnit, M, Ldim, one, Lcurr, ldl , Bcurr-Ldim, ldb ); 
    1076                                 //cerr<<"M,k="<<M<<" "<<k<<endl; 
    1077                                 //cerr<<" fgemm with M, N, K="<<M-k<<" "<<N<<" "<<Ldim<<endl; 
     1198 
     1199                                fgemm (F, FflasNoTrans, FflasNoTrans, M,  Ldim, N-ib-1, Mone, 
     1200                                       Bcurr, ldb, Rcurr, ldl,  one, Bcurr-Ldim, ldb); 
     1201 
     1202                                ftrsm (F, Side, FflasLower, FflasNoTrans, FflasUnit, M, Ldim, one, 
     1203                                       Lcurr, ldl , Bcurr-Ldim, ldb ); 
    10781204                        } 
    10791205                } 
     
    11401266protected: 
    11411267         
    1142  
    1143         // Inversion of a lower triangular matrix with a unit diagonal 
    1144 //      template<class Field> 
    1145 //      static void  
    1146 //      invL( const Field& F, const size_t N, const typename Field::Element * L, const size_t ldl, 
    1147 //            typename Field::Element * X, const size_t ldx ){ 
    1148 //              //assumes X2 is initialized to 0 
    1149 //              typename Field::Element mone, one; 
    1150 //              F.init(one,1.0); 
    1151 //              F.init(mone,-1.0); 
    1152                  
    1153 //              if (N == 1){ 
    1154 //                      F.assign(*X, one); 
    1155 //              } 
    1156 //              else{ 
    1157 //                      size_t N1 = N >> 1; 
    1158 //                      size_t N2 = N-N1; 
    1159 //                      typename Field::Element * X11 = X; 
    1160 //                      const typename Field::Element * L11 = L; 
    1161 //                      typename Field::Element * X21 = X+N1*ldx; 
    1162 //                      const typename Field::Element * L21 = L+N1*ldl; 
    1163 //                      typename Field::Element * X22 = X21+N1; 
    1164 //                      const typename Field::Element * L22 = L21+N1; 
    1165 //                      // recursive call for X11 
    1166 //                      // X11 = L11^-1 
    1167 //                      invL( F, N1, L11, ldl, X11, ldx ); 
    1168  
    1169 //                      // recursive call for X11 
    1170 //                      // X22 = L22^-1 
    1171 //                      invL( F, N2, L22, ldl, X22, ldx ); 
    1172                          
    1173 //                      // Copy L21 into X21 
    1174 //                      for ( size_t i=0; i<N2; ++i) 
    1175 //                              fcopy( F, N1, X21+i*ldx, 1, L21+i*ldl, 1 ); 
    1176  
    1177 //                      // X21 = X21 . -X11^-1 (pascal 2004-10-12, make the negation 
    1178 //                      // after the multiplication, problem in ftrmm) 
    1179 //                      ftrmm (F, FflasRight, FflasLower, FflasNoTrans, FflasUnit,  
    1180 //                             N2, N1, mone, X11, ldx, X21, ldx ); 
    1181 // //                   for (size_t i=0; i<N2; ++i) 
    1182 // //                           for (size_t j=0; j<N1; ++j) 
    1183 // //                                   F.negin(*(X21+i*ldx+j)); 
    1184  
    1185 //                      // X21 = X22^-1 . X21 
    1186 //                      ftrmm( F, FflasLeft, FflasLower, FflasNoTrans, FflasUnit, N2, N1, one, X22, ldx, X21, ldx ); 
    1187 //              } 
    1188 //      } 
    1189                  
     1268         
    11901269        // Subroutine for Keller-Gehrig charpoly algorithm 
    11911270        // Compute the new d after a LSP ( d[i] can be zero )