Changeset 21

Show
Ignore:
Timestamp:
05/01/07 19:23:34 (2 years ago)
Author:
pernet
Message:

Introduction of the new automatic choice of underlying BLAS, for any finite finite field implementation:
float/double is chosen depending on the prime, the dimension, and efficiency considerations.

Files:
1 added
12 modified

Legend:

Unmodified
Added
Removed
  • include/fflas-ffpack/fflas.h

    r18 r21  
    3232         
    3333#ifndef __LINBOX_STRASSEN_OPTIMIZATION 
    34 #define WINOTHRESHOLD 750 
     34#define WINOTHRESHOLD 400 
    3535#else 
    3636#define WINOTHRESHOLD __LINBOX_WINOTHRESHOLD 
    3737#endif 
    3838 
     39// Thresholds determining which floating point representation to use, 
     40// depending on the cardinality of the finite field. This is only used when 
     41// the element representation is not a floating point type. 
     42#define FLOAT_DOUBLE_THRESHOLD_0 430 
     43#define FLOAT_DOUBLE_THRESHOLD_1 350 
     44#define FLOAT_DOUBLE_THRESHOLD_2 175 
     45         
    3946#define DOUBLE_MANTISSA 53 
    4047#define FLOAT_MANTISSA 24 
     
    4754        enum FFLAS_DIAG      { FflasNonUnit=131, FflasUnit=132 }; 
    4855        enum FFLAS_SIDE      { FflasLeft=141, FflasRight = 142 }; 
    49          
     56 
     57        /* Determine the type of the element representation for Matrix Mult kernel 
     58         * FflasDouble: to use the double precision BLAS 
     59         * FflasFloat: to use the single precison BLAS 
     60         * FflasFloat: for any other domain, that can not be converted to floating point integers 
     61         */ 
     62        enum FFLAS_BASE      { FflasDouble = 151, FflasFloat = 152, FflasGeneric = 153}; 
     63 
     64        /* Representations of Z with floating point elements*/ 
    5065        typedef UnparametricField<float> FloatDomain; 
    51  
    5266        typedef UnparametricField<double> DoubleDomain; 
    5367 
     
    132146        template<class Field> 
    133147        static void 
    134         fgemv (const Field& F, const enum FFLAS_TRANSPOSE TransA,  
     148        fgemv (const Field& F, const FFLAS_TRANSPOSE TransA,  
    135149               const size_t M, const size_t N, 
    136150               const typename Field::Element alpha,  
     
    161175        template<class Field> 
    162176        static void 
    163         ftrsv (const Field& F, const enum FFLAS_UPLO Uplo,  
    164                const enum FFLAS_TRANSPOSE TransA, const enum FFLAS_DIAG Diag, 
     177        ftrsv (const Field& F, const FFLAS_UPLO Uplo,  
     178               const FFLAS_TRANSPOSE TransA, const FFLAS_DIAG Diag, 
    165179               const size_t N,const typename Field::Element * A, const size_t lda, 
    166180               typename Field::Element * X, int incX); 
     
    177191        template<class Field> 
    178192        static void 
    179         ftrsm (const Field& F, const enum FFLAS_SIDE Side, 
    180                const enum FFLAS_UPLO Uplo,  
    181                const enum FFLAS_TRANSPOSE TransA, 
    182                const enum FFLAS_DIAG Diag,  
     193        ftrsm (const Field& F, const FFLAS_SIDE Side, 
     194               const FFLAS_UPLO Uplo,  
     195               const FFLAS_TRANSPOSE TransA, 
     196               const FFLAS_DIAG Diag,  
    183197               const size_t M, const size_t N, 
    184198               const typename Field::Element alpha, 
     
    193207        template<class Field> 
    194208        static void 
    195         ftrmm (const Field& F, const enum FFLAS_SIDE Side, 
    196                const enum FFLAS_UPLO Uplo,  
    197                const enum FFLAS_TRANSPOSE TransA, 
    198                const enum FFLAS_DIAG Diag,  
     209        ftrmm (const Field& F, const FFLAS_SIDE Side, 
     210               const FFLAS_UPLO Uplo,  
     211               const FFLAS_TRANSPOSE TransA, 
     212               const FFLAS_DIAG Diag,  
    199213               const size_t M, const size_t N, 
    200214               const typename Field::Element alpha, 
     
    211225        static typename Field::Element*  
    212226        fgemm( const Field& F, 
    213                const enum FFLAS_TRANSPOSE ta, 
    214                const enum FFLAS_TRANSPOSE tb, 
     227               const FFLAS_TRANSPOSE ta, 
     228               const FFLAS_TRANSPOSE tb, 
    215229               const size_t m, 
    216230               const size_t n, 
     
    221235               const typename Field::Element beta, 
    222236               typename Field::Element* C, const size_t ldc, 
    223                const size_t wl); 
     237               const size_t w){ 
     238 
     239                if (!(m && n && k)) return C; 
     240                 
     241                FFLAS_BASE base = BaseCompute (F, w); 
     242 
     243                WinoMain (F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, 
     244                                 C, ldc, DotProdBound (F, w, beta, base), w, base); 
     245                return C; 
     246                }; 
    224247         
    225248        /** @brief  Field GEneral Matrix Multiply  
     
    232255        static typename Field::Element* 
    233256        fgemm (const Field& F, 
    234                const enum FFLAS_TRANSPOSE ta, 
    235                const enum FFLAS_TRANSPOSE tb, 
     257               const FFLAS_TRANSPOSE ta, 
     258               const FFLAS_TRANSPOSE tb, 
    236259               const size_t m, 
    237260               const size_t n, 
     
    245268               typename Field::Element* C,  
    246269               const size_t ldc){ 
    247                 size_t ws =0; 
    248                 if ( (ta==FflasNoTrans)  && (tb==FflasNoTrans)) { 
    249                         size_t kt = MIN(MIN(k,m),n); 
    250                         while (kt >= WINOTHRESHOLD){ 
    251                                 ws++; 
    252                                 kt/=2; 
    253                         } 
    254                 } 
    255                 return fgemm(F, ta, tb, m, n, k, alpha, A, lda, B, ldb, 
    256                              beta, C, ldc, ws); 
     270 
     271                if (!(m && n && k)) return C; 
     272 
     273                size_t w, kmax=0; 
     274                FFLAS_BASE base; 
     275 
     276                setMatMulParam<typename Field::Element> ()(F, MIN(MIN(m,n),k), beta, 
     277                                                           w, base, kmax); 
     278 
     279                WinoMain (F, ta, tb, m, n, k, alpha, A, lda, B, ldb, beta, 
     280                          C, ldc, kmax, w, base); 
     281                return C; 
    257282        } 
    258283         
     
    265290        template<class Field> 
    266291        static typename Field::Element* fsquare (const Field& F, 
    267                                                  const enum FFLAS_TRANSPOSE ta, 
     292                                                 const FFLAS_TRANSPOSE ta, 
    268293                                                 const size_t n, 
    269294                                                 const typename Field::Element alpha, 
     
    417442        template <class Field> 
    418443        static size_t DotProdBound (const Field& F, const size_t w, 
    419                                     const typename Field::Element& beta); 
     444                                    const typename Field::Element& beta, 
     445                                    const FFLAS_BASE base); 
    420446 
    421447        template <class Field> 
    422448        static size_t DotProdBoundCompute (const Field& F, const size_t w, 
    423                                            const typename Field::Element& beta); 
    424          
    425         template <class Element> 
    426         class callDotProdBoundCompute; 
     449                                           const typename Field::Element& beta, 
     450                                           const FFLAS_BASE base); 
     451         
     452 
     453        template <class Field> 
     454        static FFLAS_BASE BaseCompute (const Field& F, const size_t w); 
     455         
     456        static size_t WinoSteps (const size_t m); 
     457         
     458        //      template <class Element> 
     459//      class callDotProdBoundCompute; 
    427460 
    428461        /** @brief Bound for the delayed modulus triangular system solving 
     
    441474        class callTRSMBound; 
    442475 
     476        /** @brief Set the optimal parameters for the Matrix Multiplication 
     477         */ 
     478        template <class Element> 
     479        class setMatMulParam; 
     480 
    443481        template <class Field> 
    444482        static void DynamicPealing( const Field& F,  
    445                                     const enum FFLAS_TRANSPOSE ta, 
    446                                     const enum FFLAS_TRANSPOSE tb, 
     483                                    const FFLAS_TRANSPOSE ta, 
     484                                    const FFLAS_TRANSPOSE tb, 
    447485                                    const size_t m, const size_t n, const size_t k, 
    448486                                    const typename Field::Element alpha,  
     
    455493        template<class Field> 
    456494        static void MatVectProd (const Field& F,  
    457                                  const enum FFLAS_TRANSPOSE TransA,  
     495                                 const FFLAS_TRANSPOSE TransA,  
    458496                                 const size_t M, const size_t N, 
    459497                                 const typename Field::Element alpha,  
     
    465503        template <class Field> 
    466504        static void ClassicMatmul(const Field& F,   
    467                                   const enum FFLAS_TRANSPOSE ta, 
    468                                   const enum FFLAS_TRANSPOSE tb, 
     505                                  const FFLAS_TRANSPOSE ta, 
     506                                  const FFLAS_TRANSPOSE tb, 
    469507                                  const size_t m, const size_t n, const size_t k, 
    470508                                  const typename Field::Element alpha, 
     
    473511                                  const typename Field::Element beta, 
    474512                                  typename Field::Element * C, const size_t ldc,  
    475                                   const size_t kmax ); 
     513                                  const size_t kmax, const FFLAS_BASE base ); 
    476514     
    477515        // Winograd Multiplication  alpha.A(n*k) * B(k*m) + beta . C(n*m) 
     
    479517        template<class Field> 
    480518        static void WinoCalc (const Field& F,  
    481                               const enum FFLAS_TRANSPOSE ta, 
    482                               const enum FFLAS_TRANSPOSE tb, 
     519                              const FFLAS_TRANSPOSE ta, 
     520                              const FFLAS_TRANSPOSE tb, 
    483521                              const size_t mr, const size_t nr,const size_t kr, 
    484522                              const typename Field::Element alpha, 
     
    487525                              const typename Field::Element beta, 
    488526                              typename Field::Element * C, const size_t ldc, 
    489                               const size_t kmax, const size_t w); 
     527                              const size_t kmax, const size_t w, const FFLAS_BASE base); 
    490528         
    491529        template<class Field> 
    492530        static void WinoMain (const Field& F,  
    493                               const enum FFLAS_TRANSPOSE ta, 
    494                               const enum FFLAS_TRANSPOSE tb, 
     531                              const FFLAS_TRANSPOSE ta, 
     532                              const FFLAS_TRANSPOSE tb, 
    495533                              const size_t m, const size_t n, const size_t k, 
    496534                              const typename Field::Element alpha, 
     
    499537                              const typename Field::Element beta, 
    500538                              typename Field::Element * C, const size_t ldc, 
    501                               const size_t kmax, const size_t w); 
     539                              const size_t kmax, const size_t w, const FFLAS_BASE base); 
    502540 
    503541 
     
    516554        // Specialized routines for ftrsm 
    517555        template<class Field> 
    518         static void ftrsmLeftUpNoTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     556        static void ftrsmLeftUpNoTrans (const Field& F, const FFLAS_DIAG Diag,  
    519557                                        const size_t M, const size_t N, 
    520558                                        const typename Field::Element alpha, 
     
    523561         
    524562        template<class Field> 
    525         static void ftrsmLeftUpTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     563        static void ftrsmLeftUpTrans (const Field& F, const FFLAS_DIAG Diag,  
    526564                                      const size_t M, const size_t N, 
    527565                                      const typename Field::Element alpha, 
     
    530568         
    531569        template<class Field> 
    532         static void ftrsmLeftLowNoTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     570        static void ftrsmLeftLowNoTrans (const Field& F, const FFLAS_DIAG Diag,  
    533571                                         const size_t M, const size_t N, 
    534572                                         const typename Field::Element alpha, 
     
    567605         
    568606        template<class Field> 
    569         static void ftrsmLeftLowTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     607        static void ftrsmLeftLowTrans (const Field& F, const FFLAS_DIAG Diag,  
    570608                                       const size_t M, const size_t N, 
    571609                                       const typename Field::Element alpha, 
     
    574612         
    575613        template<class Field> 
    576         static void ftrsmRightUpNoTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     614        static void ftrsmRightUpNoTrans (const Field& F, const FFLAS_DIAG Diag,  
    577615                                         const size_t M, const size_t N, 
    578616                                         const typename Field::Element alpha, 
     
    582620         
    583621        template<class Field> 
    584         static void ftrsmRightUpTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     622        static void ftrsmRightUpTrans (const Field& F, const FFLAS_DIAG Diag,  
    585623                                       const size_t M, const size_t N, 
    586624                                       const typename Field::Element alpha, 
     
    589627 
    590628        template<class Field> 
    591         static void ftrsmRightLowNoTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     629        static void ftrsmRightLowNoTrans (const Field& F, const FFLAS_DIAG Diag,  
    592630                                          const size_t M, const size_t N, 
    593631                                          const typename Field::Element alpha, 
     
    596634 
    597635        template<class Field> 
    598         static void ftrsmRightLowTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     636        static void ftrsmRightLowTrans (const Field& F, const FFLAS_DIAG Diag,  
    599637                                        const size_t M, const size_t N, 
    600638                                        const typename Field::Element alpha, 
     
    604642        // Specialized routines for ftrmm 
    605643        template<class Field> 
    606         static void ftrmmLeftUpNoTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     644        static void ftrmmLeftUpNoTrans (const Field& F, const FFLAS_DIAG Diag,  
    607645                                        const size_t M, const size_t N, 
    608646                                        const typename Field::Element * A, const size_t lda, 
     
    611649 
    612650        template<class Field> 
    613         static void ftrmmLeftUpTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     651        static void ftrmmLeftUpTrans (const Field& F, const FFLAS_DIAG Diag,  
    614652                                      const size_t M, const size_t N, 
    615653                                      const typename Field::Element * A, const size_t lda, 
     
    618656 
    619657        template<class Field> 
    620         static void ftrmmLeftLowNoTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     658        static void ftrmmLeftLowNoTrans (const Field& F, const FFLAS_DIAG Diag,  
    621659                                         const size_t M, const size_t N, 
    622660                                         const typename Field::Element * A, const size_t lda, 
     
    625663 
    626664        template<class Field> 
    627         static void ftrmmLeftLowTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     665        static void ftrmmLeftLowTrans (const Field& F, const FFLAS_DIAG Diag,  
    628666                                       const size_t M, const size_t N, 
    629667                                       const typename Field::Element * A, const size_t lda, 
     
    632670         
    633671        template<class Field> 
    634         static void ftrmmRightUpNoTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     672        static void ftrmmRightUpNoTrans (const Field& F, const FFLAS_DIAG Diag,  
    635673                                         const size_t M, const size_t N, 
    636674                                         const typename Field::Element * A, const size_t lda, 
     
    639677 
    640678        template<class Field> 
    641         static void ftrmmRightUpTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     679        static void ftrmmRightUpTrans (const Field& F, const FFLAS_DIAG Diag,  
    642680                                       const size_t M, const size_t N, 
    643681                                       const typename Field::Element * A, const size_t lda, 
     
    646684 
    647685        template<class Field> 
    648         static void ftrmmRightLowNoTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     686        static void ftrmmRightLowNoTrans (const Field& F, const FFLAS_DIAG Diag,  
    649687                                          const size_t M, const size_t N, 
    650688                                          const typename Field::Element * A, const size_t lda, 
     
    652690                                          const size_t nmax); 
    653691        template<class Field> 
    654         static void ftrmmRightLowTrans (const Field& F, const enum FFLAS_DIAG Diag,  
     692        static void ftrmmRightLowTrans (const Field& F, const FFLAS_DIAG Diag,  
    655693                                        const size_t M, const size_t N, 
    656694                                        const typename Field::Element * A, const size_t lda, 
  • include/fflas-ffpack/fflas_bounds.inl

    r18 r21  
    1919// Computes the maximal dimension k so that the product A*B+beta.C over Z, 
    2020// where A is m*k and B is k*n can be performed correctly with w Winograd 
    21 // recursion levels on the 53 bits of double mantissa 
     21// recursion levels on the number of bits of the floating point mantissa 
    2222//--------------------------------------------------------------------- 
    2323template  <class Field>  
    2424inline size_t FFLAS::DotProdBoundCompute (const Field& F, const size_t w,  
    25                                           const typename Field::Element& beta){ 
    26         return callDotProdBoundCompute<typename Field::Element>() (F, w, beta); 
    27 } 
    28  
    29 template<class Element> 
    30 class FFLAS::callDotProdBoundCompute { 
    31 public: 
    32         template  <class Field>  
    33         size_t operator () (const Field& F, const size_t w,  
    34                             const typename Field::Element& beta) 
    35         { 
    36                 typename Field::Element mone; 
    37                 static FFLAS_INT_TYPE p; 
    38                 F.characteristic(p); 
    39                 F.init (mone, -1.0); 
    40                 size_t kmax; 
    41                 if (p == 0) 
    42                         kmax = 2; 
    43                 else 
    44                         if (w > 0) { 
    45                                 size_t ex=1; 
    46                                 for (size_t i=0; i < w; ++i)    ex *= 3; 
    47                                 //FFLAS_INT_TYPE c = (p-1)*(ex)/2; //bound for a centered representation 
    48                                 long long c = (p-1)*(1+ex)/2; 
    49                                 kmax =  lround(( double(1ULL << DOUBLE_MANTISSA) /double(c*c) + 1)*(1 << w)); 
    50                                 if (kmax ==  ( 1ULL << w)) 
    51                                         kmax = 2; 
    52                         } 
    53                         else{ 
    54                                 long long c = p-1; 
    55                                 long long cplt=0; 
    56                                 if (!F.isZero (beta)) 
    57                                         if (F.isOne (beta) || F.areEqual (beta, mone)) 
    58                                                 cplt = c; 
    59                                         else cplt = c*c; 
    60                                 kmax =  lround(( double((1ULL << DOUBLE_MANTISSA) - cplt)) /double(c*c)); 
    61                                 if (kmax  < 2) 
    62                                         kmax = 2; 
    63                         } 
    64                 return  MIN(kmax,1ULL<<31); 
    65         } 
    66 }; 
    67  
    68 template  <>  
    69 class FFLAS::callDotProdBoundCompute<double> { 
    70 public: 
    71         template <class Field> 
    72         size_t operator() (const Field& F, const size_t w,  
    73                            const double& beta) 
    74         { 
    75                 double mone; 
    76                 static FFLAS_INT_TYPE p; 
    77                 F.characteristic(p); 
    78                 F.init (mone, -1.0); 
    79                 size_t  kmax; 
    80                 if (p == 0) 
    81                         kmax = 2; 
    82                 else 
    83                         if (w > 0) { 
    84                                 size_t ex=1; 
    85                                 for (size_t i=0; i < w; ++i)    ex *= 3; 
    86                                 long long c; 
     25                                          const typename Field::Element& beta, 
     26                                          const FFLAS_BASE base){ 
     27         
     28        typename Field::Element mone; 
     29        static FFLAS_INT_TYPE p; 
     30        F.characteristic(p); 
     31        F.init (mone, -1.0); 
     32        size_t kmax; 
     33 
     34        unsigned long mantissa = (base == FflasDouble)? DOUBLE_MANTISSA : FLOAT_MANTISSA; 
     35         
     36        if (p == 0) 
     37                kmax = 2; 
     38        else 
     39                if (w > 0) { 
     40                        size_t ex=1; 
     41                        for (size_t i=0; i < w; ++i)    ex *= 3; 
     42                        //FFLAS_INT_TYPE c = (p-1)*(ex)/2; //bound for a centered representation 
     43                        long long c; 
    8744#ifndef _LINBOX_CONFIG_H 
    88                                 if (F.balanced) 
    89                                         c = (p-1)*(ex)/2; // balanced representation 
    90                                 else 
     45                        if (F.balanced) 
     46                                c = (p-1)*(ex)/2; // balanced representation 
     47                        else 
    9148#endif 
    92                                         c = (p-1)*(1+ex)/2; // positive representation 
    93                                 kmax =  lround(double(1ULL << DOUBLE_MANTISSA) /(double(c*c) + 1)*(1ULL << w)); 
    94                                 if (kmax ==  ( 1ULL << w)) 
    95                                         kmax = 2; 
    96                         } 
    97                         else{ 
    98                                 long long c = p-1; 
    99                                 long long cplt=0; 
    100                                 if (!F.isZero (beta)) 
    101                                         if (F.isOne (beta) || F.areEqual (beta, mone)) 
    102                                                 cplt = c; 
    103                                         else cplt = c*c; 
    104                                 kmax =  lround( double((1ULL << DOUBLE_MANTISSA) - cplt) /(double(c*c))); 
    105                                 if (kmax  < 2) 
    106                                         kmax = 2; 
    107                         } 
    108                 return (size_t) MIN(kmax,1ULL<<31); 
    109         } 
    110 }; 
    111  
    112 template  <>  
    113 class FFLAS::callDotProdBoundCompute<float> { 
    114 public: 
    115         template <class Field> 
    116         size_t operator () (const Field& F, const size_t w,  
    117                             const float& beta) 
    118         { 
    119                 float mone,one; 
    120                 static FFLAS_INT_TYPE p; 
    121                 F.characteristic(p); 
    122                 F.init (one, 1.0F); 
    123                 F.neg(mone,one); 
    124                 size_t  kmax; 
    125                 if (p == 0) 
    126                         kmax = 2; 
    127                 else 
    128                         if (w > 0) { 
    129                                 size_t ex=1; 
    130                                 for (size_t i=0; i < w; ++i)  &n