MethodHMatrix.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: MethodHMatrix.cxx 36966 2010-11-26 09:50:13Z evt $
00002 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate Data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : TMVA::MethodHMatrix                                                   *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Implementation (see header file for description)                          *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00015  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00016  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
00017  *                                                                                *
00018  * Copyright (c) 2005:                                                            *
00019  *      CERN, Switzerland                                                         *
00020  *      U. of Victoria, Canada                                                    *
00021  *      MPI-K Heidelberg, Germany                                                 *
00022  *                                                                                *
00023  * Redistribution and use in source and binary forms, with or without             *
00024  * modification, are permitted according to the terms listed in LICENSE           *
00025  * (http://tmva.sourceforge.net/LICENSE)                                          *
00026  **********************************************************************************/
00027 
#include "TMVA/ClassifierFactory.h"
#include "TMVA/MethodHMatrix.h"
#include "TMVA/Tools.h"
#include "TMatrix.h"
#include "Riostream.h"
#include <algorithm>
#include <vector>
00034 
// Register this method under the name "HMatrix" with the TMVA classifier
// factory (macro presumably provided by TMVA/ClassifierFactory.h, included above).
REGISTER_METHOD(HMatrix)

// ROOT class-implementation macro for TMVA::MethodHMatrix (dictionary hook).
ClassImp(TMVA::MethodHMatrix)
00038 
00039 //_______________________________________________________________________
00040 //Begin_Html
00041 /*
00042   H-Matrix method, which is implemented as a simple comparison of
00043   chi-squared estimators for signal and background, taking into
00044   account the linear correlations between the input variables
00045 
  This MVA approach is used by the D&#216; collaboration (FNAL) for the
  purpose of electron identification (see, e.g.,
  <a href="http://arxiv.org/abs/hep-ex/9507007">hep-ex/9507007</a>).
  As implemented in TMVA, it usually performs equivalently to or worse than
  the Fisher-Mahalanobis discriminant, and it has been added only for
  the sake of completeness.
  Two &chi;<sup>2</sup> estimators are computed for each event, one for the
  signal and one for the background hypothesis, using the means and
  covariance matrices estimated from the training sample:<br>
00055   <center>
00056   <img vspace=6 src="gif/tmva_chi2.gif" align="bottom" >
00057   </center>
  TMVA then uses as normalised analyser for event (<i>i</i>) the ratio:
  (<i>&chi;<sub>B</sub><sup>2</sup>(i) &minus; &chi;<sub>S</sub><sup>2</sup>(i)</i>) /
  (<i>&chi;<sub>S</sub><sup>2</sup>(i) + &chi;<sub>B</sub><sup>2</sup>(i)</i>),
  which lies between &minus;1 and +1; positive values indicate signal-like events.
00061 */
00062 //End_Html
00063 //_______________________________________________________________________
00064 
00065 
00066 //_______________________________________________________________________
00067 TMVA::MethodHMatrix::MethodHMatrix( const TString& jobName,
00068                                     const TString& methodTitle,
00069                                     DataSetInfo& theData,
00070                                     const TString& theOption,
00071                                     TDirectory* theTargetDir )
00072    : TMVA::MethodBase( jobName, Types::kHMatrix, methodTitle, theData, theOption, theTargetDir )
00073 {
00074    // standard constructor for the H-Matrix method
00075 }
00076 
00077 //_______________________________________________________________________
00078 TMVA::MethodHMatrix::MethodHMatrix( DataSetInfo& theData,
00079                                     const TString& theWeightFile,
00080                                     TDirectory* theTargetDir )
00081    : TMVA::MethodBase( Types::kHMatrix, theData, theWeightFile, theTargetDir )
00082 {
00083    // constructor from weight file
00084 }
00085 
00086 //_______________________________________________________________________
00087 void TMVA::MethodHMatrix::Init( void )
00088 {
00089    // default initialization called by all constructors
00090 
00091    //SetNormalised( kFALSE ); obsolete!
00092 
00093    fInvHMatrixS = new TMatrixD( GetNvar(), GetNvar() );
00094    fInvHMatrixB = new TMatrixD( GetNvar(), GetNvar() );
00095    fVecMeanS    = new TVectorD( GetNvar() );
00096    fVecMeanB    = new TVectorD( GetNvar() );
00097 
00098    // the minimum requirement to declare an event signal-like
00099    SetSignalReferenceCut( 0.0 );
00100 }
00101 
00102 //_______________________________________________________________________
00103 TMVA::MethodHMatrix::~MethodHMatrix( void )
00104 {
00105    // destructor
00106    if (NULL != fInvHMatrixS) delete fInvHMatrixS;
00107    if (NULL != fInvHMatrixB) delete fInvHMatrixB;
00108    if (NULL != fVecMeanS   ) delete fVecMeanS;
00109    if (NULL != fVecMeanB   ) delete fVecMeanB;
00110 }
00111 
00112 //_______________________________________________________________________
00113 Bool_t TMVA::MethodHMatrix::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ )
00114 {
00115    // FDA can handle classification with 2 classes and regression with one regression-target
00116    if( type == Types::kClassification && numberClasses == 2 ) return kTRUE;
00117    return kFALSE;
00118 }
00119 
00120 
00121 //_______________________________________________________________________
00122 void TMVA::MethodHMatrix::DeclareOptions()
00123 {
00124    // MethodHMatrix options: none (apart from those implemented in MethodBase)
00125 }
00126 
00127 //_______________________________________________________________________
00128 void TMVA::MethodHMatrix::ProcessOptions()
00129 {
00130    // process user options
00131 }
00132 
00133 //_______________________________________________________________________
00134 void TMVA::MethodHMatrix::Train( void )
00135 {
00136    // computes H-matrices for signal and background samples
00137 
00138    // covariance matrices for signal and background
00139    ComputeCovariance( kTRUE,  fInvHMatrixS );
00140    ComputeCovariance( kFALSE, fInvHMatrixB );
00141 
00142    // sanity checks
00143    if (TMath::Abs(fInvHMatrixS->Determinant()) < 10E-24) {
00144       Log() << kWARNING << "<Train> H-matrix  S is almost singular with deterinant= "
00145             << TMath::Abs(fInvHMatrixS->Determinant())
00146             << " did you use the variables that are linear combinations or highly correlated ???"
00147             << Endl;
00148    }
00149    if (TMath::Abs(fInvHMatrixB->Determinant()) < 10E-24) {
00150       Log() << kWARNING << "<Train> H-matrix  B is almost singular with deterinant= "
00151             << TMath::Abs(fInvHMatrixB->Determinant())
00152             << " did you use the variables that are linear combinations or highly correlated ???"
00153             << Endl;
00154    }
00155 
00156     if (TMath::Abs(fInvHMatrixS->Determinant()) < 10E-120) {
00157        Log() << kFATAL << "<Train> H-matrix  S is singular with deterinant= "
00158              << TMath::Abs(fInvHMatrixS->Determinant())
00159              << " did you use the variables that are linear combinations ???"
00160              << Endl;
00161     }
00162     if (TMath::Abs(fInvHMatrixB->Determinant()) < 10E-120) {
00163        Log() << kFATAL << "<Train> H-matrix  B is singular with deterinant= "
00164              << TMath::Abs(fInvHMatrixB->Determinant())
00165              << " did you use the variables that are linear combinations ???"
00166              << Endl;
00167     }
00168 
00169    // invert matrix
00170    fInvHMatrixS->Invert();
00171    fInvHMatrixB->Invert();
00172 }
00173 
00174 //_______________________________________________________________________
00175 void TMVA::MethodHMatrix::ComputeCovariance( Bool_t isSignal, TMatrixD* mat )
00176 {
00177    // compute covariance matrix
00178 
00179    Data()->SetCurrentType(Types::kTraining);
00180 
00181    const UInt_t nvar = DataInfo().GetNVariables();
00182    UInt_t ivar, jvar;
00183 
00184    // init matrices
00185    TVectorD vec(nvar);        vec  *= 0;
00186    TMatrixD mat2(nvar, nvar); mat2 *= 0;
00187 
00188    // initialize internal sum-of-weights variables
00189    Double_t sumOfWeights = 0;
00190    Double_t *xval = new Double_t[nvar];
00191 
00192    // perform event loop
00193    for (Int_t i=0; i<Data()->GetNEvents(); i++) {
00194 
00195       // retrieve the event
00196       const Event* ev = GetEvent(i);
00197       Double_t weight = ev->GetWeight();
00198 
00199       // in case event with neg weights are to be ignored
00200       if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;
00201 
00202       if (DataInfo().IsSignal(ev) != isSignal) continue;
00203 
00204       // event is of good type
00205       sumOfWeights += weight;
00206 
00207       // mean values
00208       for (ivar=0; ivar<nvar; ivar++) xval[ivar] = ev->GetValue(ivar);
00209 
00210       // covariance matrix
00211       for (ivar=0; ivar<nvar; ivar++) {
00212 
00213          vec(ivar)        += xval[ivar]*weight;
00214          mat2(ivar, ivar) += (xval[ivar]*xval[ivar])*weight;
00215 
00216          for (jvar=ivar+1; jvar<nvar; jvar++) {
00217             mat2(ivar, jvar) += (xval[ivar]*xval[jvar])*weight;
00218             mat2(jvar, ivar) = mat2(ivar, jvar); // symmetric matrix
00219          }
00220       }
00221    }
00222 
00223    // variance-covariance
00224    for (ivar=0; ivar<nvar; ivar++) {
00225 
00226       if (isSignal) (*fVecMeanS)(ivar) = vec(ivar)/sumOfWeights;
00227       else          (*fVecMeanB)(ivar) = vec(ivar)/sumOfWeights;
00228 
00229       for (jvar=0; jvar<nvar; jvar++) {
00230          (*mat)(ivar, jvar) = mat2(ivar, jvar)/sumOfWeights - vec(ivar)*vec(jvar)/(sumOfWeights*sumOfWeights);
00231       }
00232    }
00233 
00234    delete [] xval;
00235 }
00236 
00237 //_______________________________________________________________________
00238 Double_t TMVA::MethodHMatrix::GetMvaValue( Double_t* err, Double_t* errUpper )
00239 {
00240    // returns the H-matrix signal estimator
00241    Double_t s = GetChi2( Types::kSignal     );
00242    Double_t b = GetChi2( Types::kBackground );
00243   
00244    if (s+b < 0) Log() << kFATAL << "big trouble: s+b: " << s+b << Endl;
00245 
00246    // cannot determine error
00247    NoErrorCalc(err, errUpper);
00248 
00249    return (b - s)/(s + b);
00250 }
00251 
00252 //_______________________________________________________________________
00253 Double_t TMVA::MethodHMatrix::GetChi2( TMVA::Event* e,  Types::ESBType type ) const
00254 {
00255    // compute chi2-estimator for event according to type (signal/background)
00256 
00257    // loop over variables
00258    UInt_t ivar,jvar;
00259    vector<Double_t> val( GetNvar() );
00260    for (ivar=0; ivar<GetNvar(); ivar++) {
00261       val[ivar] = e->GetValue(ivar);
00262       if (IsNormalised()) val[ivar] = gTools().NormVariable( val[ivar], GetXmin( ivar ), GetXmax( ivar ) );
00263    }
00264 
00265    Double_t chi2 = 0;
00266    for (ivar=0; ivar<GetNvar(); ivar++) {
00267       for (jvar=0; jvar<GetNvar(); jvar++) {
00268          if (type == Types::kSignal) 
00269             chi2 += ( (val[ivar] - (*fVecMeanS)(ivar))*(val[jvar] - (*fVecMeanS)(jvar))
00270                       * (*fInvHMatrixS)(ivar,jvar) );
00271          else
00272             chi2 += ( (val[ivar] - (*fVecMeanB)(ivar))*(val[jvar] - (*fVecMeanB)(jvar))
00273                       * (*fInvHMatrixB)(ivar,jvar) );
00274       }
00275    }
00276 
00277    // sanity check
00278    if (chi2 < 0) Log() << kFATAL << "<GetChi2> negative chi2: " << chi2 << Endl;
00279 
00280    return chi2;
00281 }
00282 
00283 //_______________________________________________________________________
00284 Double_t TMVA::MethodHMatrix::GetChi2( Types::ESBType type ) const
00285 {
00286    // compute chi2-estimator for event according to type (signal/background)
00287 
00288    const Event * ev = GetEvent();
00289 
00290    // loop over variables
00291    UInt_t ivar,jvar;
00292    vector<Double_t> val( GetNvar() );
00293    for (ivar=0; ivar<GetNvar(); ivar++) val[ivar] = ev->GetValue( ivar );
00294 
00295    Double_t chi2 = 0;
00296    for (ivar=0; ivar<GetNvar(); ivar++) {
00297       for (jvar=0; jvar<GetNvar(); jvar++) {
00298          if (type == Types::kSignal) 
00299             chi2 += ( (val[ivar] - (*fVecMeanS)(ivar))*(val[jvar] - (*fVecMeanS)(jvar))
00300                       * (*fInvHMatrixS)(ivar,jvar) );
00301          else
00302             chi2 += ( (val[ivar] - (*fVecMeanB)(ivar))*(val[jvar] - (*fVecMeanB)(jvar))
00303                       * (*fInvHMatrixB)(ivar,jvar) );
00304       }
00305    }
00306 
00307    // sanity check
00308    if (chi2 < 0) Log() << kFATAL << "<GetChi2> negative chi2: " << chi2 << Endl;
00309 
00310    return chi2;
00311 }
00312 
00313 //_______________________________________________________________________
00314 void TMVA::MethodHMatrix::AddWeightsXMLTo( void* parent ) const {
00315    void* wght = gTools().AddChild(parent, "Weights");
00316    gTools().WriteTVectorDToXML(wght,"VecMeanS",fVecMeanS); 
00317    gTools().WriteTVectorDToXML(wght,"VecMeanB", fVecMeanB);
00318    gTools().WriteTMatrixDToXML(wght,"InvHMatS",fInvHMatrixS); 
00319    gTools().WriteTMatrixDToXML(wght,"InvHMatB",fInvHMatrixB);
00320    //Log() << kFATAL << "Please implement writing of weights as XML" << Endl;
00321 }
00322 
00323 void TMVA::MethodHMatrix::ReadWeightsFromXML( void* wghtnode ){
00324    void* descnode = gTools().GetChild(wghtnode);
00325    gTools().ReadTVectorDFromXML(descnode,"VecMeanS",fVecMeanS);
00326    descnode = gTools().GetNextChild(descnode);
00327    gTools().ReadTVectorDFromXML(descnode,"VecMeanB", fVecMeanB);
00328    descnode = gTools().GetNextChild(descnode);
00329    gTools().ReadTMatrixDFromXML(descnode,"InvHMatS",fInvHMatrixS); 
00330    descnode = gTools().GetNextChild(descnode);
00331    gTools().ReadTMatrixDFromXML(descnode,"InvHMatB",fInvHMatrixB);
00332 }
00333 
00334 //_______________________________________________________________________
00335 void  TMVA::MethodHMatrix::ReadWeightsFromStream( istream& istr )
00336 {
00337    // read variable names and min/max
00338    // NOTE: the latter values are mandatory for the normalisation 
00339    // in the reader application !!!
00340    UInt_t ivar,jvar;
00341    TString var, dummy;
00342    istr >> dummy;
00343    //this->SetMethodName(dummy);
00344 
00345    // mean vectors
00346    for (ivar=0; ivar<GetNvar(); ivar++) 
00347       istr >> (*fVecMeanS)(ivar) >> (*fVecMeanB)(ivar);
00348 
00349    // inverse covariance matrices (signal)
00350    for (ivar=0; ivar<GetNvar(); ivar++) 
00351       for (jvar=0; jvar<GetNvar(); jvar++) 
00352          istr >> (*fInvHMatrixS)(ivar,jvar);
00353 
00354    // inverse covariance matrices (background)
00355    for (ivar=0; ivar<GetNvar(); ivar++) 
00356       for (jvar=0; jvar<GetNvar(); jvar++) 
00357          istr >> (*fInvHMatrixB)(ivar,jvar);
00358 }
00359 
00360 //_______________________________________________________________________
00361 void TMVA::MethodHMatrix::MakeClassSpecific( std::ostream& fout, const TString& className ) const
00362 {
00363    // write Fisher-specific classifier response
00364    fout << "   // arrays of input evt vs. variable " << endl;
00365    fout << "   double fInvHMatrixS[" << GetNvar() << "][" << GetNvar() << "]; // inverse H-matrix (signal)" << endl;
00366    fout << "   double fInvHMatrixB[" << GetNvar() << "][" << GetNvar() << "]; // inverse H-matrix (background)" << endl;
00367    fout << "   double fVecMeanS[" << GetNvar() << "];    // vector of mean values (signal)" << endl;
00368    fout << "   double fVecMeanB[" << GetNvar() << "];    // vector of mean values (background)" << endl;
00369    fout << "   " << endl;
00370    fout << "   double GetChi2( const std::vector<double>& inputValues, int type ) const;" << endl;
00371    fout << "};" << endl;
00372    fout << "   " << endl;
00373    fout << "void " << className << "::Initialize() " << endl;
00374    fout << "{" << endl;
00375    fout << "   // init vectors with mean values" << endl;
00376    for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
00377       fout << "   fVecMeanS[" << ivar << "] = " << (*fVecMeanS)(ivar) << ";" << endl;
00378       fout << "   fVecMeanB[" << ivar << "] = " << (*fVecMeanB)(ivar) << ";" << endl;
00379    }
00380    fout << "   " << endl;
00381    fout << "   // init H-matrices" << endl;
00382    for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
00383       for (UInt_t jvar=0; jvar<GetNvar(); jvar++) {
00384          fout << "   fInvHMatrixS[" << ivar << "][" << jvar << "] = " 
00385               << (*fInvHMatrixS)(ivar,jvar) << ";" << endl;
00386          fout << "   fInvHMatrixB[" << ivar << "][" << jvar << "] = " 
00387               << (*fInvHMatrixB)(ivar,jvar) << ";" << endl;
00388       }
00389    }
00390    fout << "}" << endl;
00391    fout << "   " << endl;
00392    fout << "inline double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const" << endl;
00393    fout << "{" << endl;
00394    fout << "   // returns the H-matrix signal estimator" << endl;
00395    fout << "   double s = GetChi2( inputValues, " << Types::kSignal << " );" << endl;
00396    fout << "   double b = GetChi2( inputValues, " << Types::kBackground << " );" << endl;
00397    fout << "   " << endl;
00398    fout << "   if (s+b <= 0) std::cout << \"Problem in class " << className << "::GetMvaValue__: s+b = \"" << endl;
00399    fout << "                           << s+b << \" <= 0 \"  << std::endl;" << endl;
00400    fout << "   " << endl;
00401    fout << "   return (b - s)/(s + b);" << endl;
00402    fout << "}" << endl;
00403    fout << "   " << endl;
00404    fout << "inline double " << className << "::GetChi2( const std::vector<double>& inputValues, int type ) const" << endl;
00405    fout << "{" << endl;
00406    fout << "   // compute chi2-estimator for event according to type (signal/background)" << endl;
00407    fout << "   " << endl;
00408    fout << "   size_t ivar,jvar;" << endl;
00409    fout << "   double chi2 = 0;" << endl;
00410    fout << "   for (ivar=0; ivar<GetNvar(); ivar++) {" << endl;
00411    fout << "      for (jvar=0; jvar<GetNvar(); jvar++) {" << endl;
00412    fout << "         if (type == " << Types::kSignal << ") " << endl;
00413    fout << "            chi2 += ( (inputValues[ivar] - fVecMeanS[ivar])*(inputValues[jvar] - fVecMeanS[jvar])" << endl;
00414    fout << "                      * fInvHMatrixS[ivar][jvar] );" << endl;
00415    fout << "         else" << endl;
00416    fout << "            chi2 += ( (inputValues[ivar] - fVecMeanB[ivar])*(inputValues[jvar] - fVecMeanB[jvar])" << endl;
00417    fout << "                      * fInvHMatrixB[ivar][jvar] );" << endl;
00418    fout << "      }" << endl;
00419    fout << "   }   // loop over variables   " << endl;
00420    fout << "   " << endl;
00421    fout << "   // sanity check" << endl;
00422    fout << "   if (chi2 < 0) std::cout << \"Problem in class " << className << "::GetChi2: chi2 = \"" << endl;
00423    fout << "                           << chi2 << \" < 0 \"  << std::endl;" << endl;
00424    fout << "   " << endl;
00425    fout << "   return chi2;" << endl;
00426    fout << "}" << endl;
00427    fout << "   " << endl;
00428    fout << "// Clean up" << endl;
00429    fout << "inline void " << className << "::Clear() " << endl;
00430    fout << "{" << endl;
00431    fout << "   // nothing to clear" << endl;
00432    fout << "}" << endl;
00433 }
00434 
00435 //_______________________________________________________________________
00436 void TMVA::MethodHMatrix::GetHelpMessage() const
00437 {
00438    // get help message text
00439    //
00440    // typical length of text line: 
00441    //         "|--------------------------------------------------------------|"
00442    Log() << Endl;
00443    Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
00444    Log() << Endl;
00445    Log() << "The H-Matrix classifier discriminates one class (signal) of a feature" << Endl;
00446    Log() << "vector from another (background). The correlated elements of the" << Endl;
00447    Log() << "vector are assumed to be Gaussian distributed, and the inverse of" << Endl;
00448    Log() << "the covariance matrix is the H-Matrix. A multivariate chi-squared" << Endl;
00449    Log() << "estimator is built that exploits differences in the mean values of" << Endl;
00450    Log() << "the vector elements between the two classes for the purpose of" << Endl;
00451    Log() << "discrimination." << Endl;
00452    Log() << Endl;
00453    Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
00454    Log() << Endl;
00455    Log() << "The TMVA implementation of the H-Matrix classifier has been shown" << Endl;
00456    Log() << "to underperform in comparison with the corresponding Fisher discriminant," << Endl;
00457    Log() << "when using similar assumptions and complexity. Its use is therefore" << Endl;
00458    Log() << "depreciated. Only in cases where the background model is strongly" << Endl;
00459    Log() << "non-Gaussian, H-Matrix may perform better than Fisher. In such" << Endl;
00460    Log() << "occurrences the user is advised to employ non-linear classifiers. " << Endl;
00461    Log() << Endl;
00462    Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
00463    Log() << Endl;
00464    Log() << "None" << Endl;
00465 }

Generated on Tue Jul 5 15:25:03 2011 for ROOT_528-00b_version by  doxygen 1.5.1