00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
#include "TMVA/ClassifierFactory.h"
#include "TMVA/MethodHMatrix.h"
#include "TMVA/Tools.h"
#include "TMatrix.h"
#include "Riostream.h"
#include <algorithm>
#include <vector>
00034
00035 REGISTER_METHOD(HMatrix)
00036
00037 ClassImp(TMVA::MethodHMatrix)
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
////////////////////////////////////////////////////////////////////////////////
/// Standard constructor for the H-Matrix method: forwards everything to
/// MethodBase with the fixed method type Types::kHMatrix. The matrices and
/// mean vectors themselves are allocated in Init().
TMVA::MethodHMatrix::MethodHMatrix( const TString& jobName,
                                    const TString& methodTitle,
                                    DataSetInfo& theData,
                                    const TString& theOption,
                                    TDirectory* theTargetDir )
   : TMVA::MethodBase( jobName, Types::kHMatrix, methodTitle, theData, theOption, theTargetDir )
{
   // nothing to do here — all state is set up in Init()
}
00076
00077
////////////////////////////////////////////////////////////////////////////////
/// Constructor used when instantiating the method from an existing weight
/// file (reader mode); the weights are read later via ReadWeightsFromXML /
/// ReadWeightsFromStream.
TMVA::MethodHMatrix::MethodHMatrix( DataSetInfo& theData,
                                    const TString& theWeightFile,
                                    TDirectory* theTargetDir )
   : TMVA::MethodBase( Types::kHMatrix, theData, theWeightFile, theTargetDir )
{
   // nothing to do here — all state is set up in Init()
}
00085
00086
////////////////////////////////////////////////////////////////////////////////
/// Default initialization called by all constructors: allocates the
/// class-conditional mean vectors and the (to-be-inverted) covariance
/// matrices for signal and background. They are filled in Train().
void TMVA::MethodHMatrix::Init( void )
{
   // one NxN matrix and one N-vector per class, N = number of input variables
   fInvHMatrixS = new TMatrixD( GetNvar(), GetNvar() );
   fInvHMatrixB = new TMatrixD( GetNvar(), GetNvar() );
   fVecMeanS    = new TVectorD( GetNvar() );
   fVecMeanB    = new TVectorD( GetNvar() );

   // MVA output cut above which an event is classified signal-like;
   // the estimator (b-s)/(s+b) is symmetric around zero
   SetSignalReferenceCut( 0.0 );
}
00101
00102
00103 TMVA::MethodHMatrix::~MethodHMatrix( void )
00104 {
00105
00106 if (NULL != fInvHMatrixS) delete fInvHMatrixS;
00107 if (NULL != fInvHMatrixB) delete fInvHMatrixB;
00108 if (NULL != fVecMeanS ) delete fVecMeanS;
00109 if (NULL != fVecMeanB ) delete fVecMeanB;
00110 }
00111
00112
00113 Bool_t TMVA::MethodHMatrix::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
00114 {
00115
00116 if( type == Types::kClassification && numberClasses == 2 ) return kTRUE;
00117 return kFALSE;
00118 }
00119
00120
00121
////////////////////////////////////////////////////////////////////////////////
/// MethodHMatrix has no method-specific configuration options
/// (see also GetHelpMessage()); intentionally empty.
void TMVA::MethodHMatrix::DeclareOptions()
{

}
00126
00127
////////////////////////////////////////////////////////////////////////////////
/// Nothing to process, since DeclareOptions() declares no options;
/// intentionally empty.
void TMVA::MethodHMatrix::ProcessOptions()
{

}
00132
00133
00134 void TMVA::MethodHMatrix::Train( void )
00135 {
00136
00137
00138
00139 ComputeCovariance( kTRUE, fInvHMatrixS );
00140 ComputeCovariance( kFALSE, fInvHMatrixB );
00141
00142
00143 if (TMath::Abs(fInvHMatrixS->Determinant()) < 10E-24) {
00144 Log() << kWARNING << "<Train> H-matrix S is almost singular with deterinant= "
00145 << TMath::Abs(fInvHMatrixS->Determinant())
00146 << " did you use the variables that are linear combinations or highly correlated ???"
00147 << Endl;
00148 }
00149 if (TMath::Abs(fInvHMatrixB->Determinant()) < 10E-24) {
00150 Log() << kWARNING << "<Train> H-matrix B is almost singular with deterinant= "
00151 << TMath::Abs(fInvHMatrixB->Determinant())
00152 << " did you use the variables that are linear combinations or highly correlated ???"
00153 << Endl;
00154 }
00155
00156 if (TMath::Abs(fInvHMatrixS->Determinant()) < 10E-120) {
00157 Log() << kFATAL << "<Train> H-matrix S is singular with deterinant= "
00158 << TMath::Abs(fInvHMatrixS->Determinant())
00159 << " did you use the variables that are linear combinations ???"
00160 << Endl;
00161 }
00162 if (TMath::Abs(fInvHMatrixB->Determinant()) < 10E-120) {
00163 Log() << kFATAL << "<Train> H-matrix B is singular with deterinant= "
00164 << TMath::Abs(fInvHMatrixB->Determinant())
00165 << " did you use the variables that are linear combinations ???"
00166 << Endl;
00167 }
00168
00169
00170 fInvHMatrixS->Invert();
00171 fInvHMatrixB->Invert();
00172 }
00173
00174
00175 void TMVA::MethodHMatrix::ComputeCovariance( Bool_t isSignal, TMatrixD* mat )
00176 {
00177
00178
00179 Data()->SetCurrentType(Types::kTraining);
00180
00181 const UInt_t nvar = DataInfo().GetNVariables();
00182 UInt_t ivar, jvar;
00183
00184
00185 TVectorD vec(nvar); vec *= 0;
00186 TMatrixD mat2(nvar, nvar); mat2 *= 0;
00187
00188
00189 Double_t sumOfWeights = 0;
00190 Double_t *xval = new Double_t[nvar];
00191
00192
00193 for (Int_t i=0; i<Data()->GetNEvents(); i++) {
00194
00195
00196 const Event* ev = GetEvent(i);
00197 Double_t weight = ev->GetWeight();
00198
00199
00200 if (IgnoreEventsWithNegWeightsInTraining() && weight <= 0) continue;
00201
00202 if (DataInfo().IsSignal(ev) != isSignal) continue;
00203
00204
00205 sumOfWeights += weight;
00206
00207
00208 for (ivar=0; ivar<nvar; ivar++) xval[ivar] = ev->GetValue(ivar);
00209
00210
00211 for (ivar=0; ivar<nvar; ivar++) {
00212
00213 vec(ivar) += xval[ivar]*weight;
00214 mat2(ivar, ivar) += (xval[ivar]*xval[ivar])*weight;
00215
00216 for (jvar=ivar+1; jvar<nvar; jvar++) {
00217 mat2(ivar, jvar) += (xval[ivar]*xval[jvar])*weight;
00218 mat2(jvar, ivar) = mat2(ivar, jvar);
00219 }
00220 }
00221 }
00222
00223
00224 for (ivar=0; ivar<nvar; ivar++) {
00225
00226 if (isSignal) (*fVecMeanS)(ivar) = vec(ivar)/sumOfWeights;
00227 else (*fVecMeanB)(ivar) = vec(ivar)/sumOfWeights;
00228
00229 for (jvar=0; jvar<nvar; jvar++) {
00230 (*mat)(ivar, jvar) = mat2(ivar, jvar)/sumOfWeights - vec(ivar)*vec(jvar)/(sumOfWeights*sumOfWeights);
00231 }
00232 }
00233
00234 delete [] xval;
00235 }
00236
00237
00238 Double_t TMVA::MethodHMatrix::GetMvaValue( Double_t* err, Double_t* errUpper )
00239 {
00240
00241 Double_t s = GetChi2( Types::kSignal );
00242 Double_t b = GetChi2( Types::kBackground );
00243
00244 if (s+b < 0) Log() << kFATAL << "big trouble: s+b: " << s+b << Endl;
00245
00246
00247 NoErrorCalc(err, errUpper);
00248
00249 return (b - s)/(s + b);
00250 }
00251
00252
00253 Double_t TMVA::MethodHMatrix::GetChi2( TMVA::Event* e, Types::ESBType type ) const
00254 {
00255
00256
00257
00258 UInt_t ivar,jvar;
00259 vector<Double_t> val( GetNvar() );
00260 for (ivar=0; ivar<GetNvar(); ivar++) {
00261 val[ivar] = e->GetValue(ivar);
00262 if (IsNormalised()) val[ivar] = gTools().NormVariable( val[ivar], GetXmin( ivar ), GetXmax( ivar ) );
00263 }
00264
00265 Double_t chi2 = 0;
00266 for (ivar=0; ivar<GetNvar(); ivar++) {
00267 for (jvar=0; jvar<GetNvar(); jvar++) {
00268 if (type == Types::kSignal)
00269 chi2 += ( (val[ivar] - (*fVecMeanS)(ivar))*(val[jvar] - (*fVecMeanS)(jvar))
00270 * (*fInvHMatrixS)(ivar,jvar) );
00271 else
00272 chi2 += ( (val[ivar] - (*fVecMeanB)(ivar))*(val[jvar] - (*fVecMeanB)(jvar))
00273 * (*fInvHMatrixB)(ivar,jvar) );
00274 }
00275 }
00276
00277
00278 if (chi2 < 0) Log() << kFATAL << "<GetChi2> negative chi2: " << chi2 << Endl;
00279
00280 return chi2;
00281 }
00282
00283
00284 Double_t TMVA::MethodHMatrix::GetChi2( Types::ESBType type ) const
00285 {
00286
00287
00288 const Event * ev = GetEvent();
00289
00290
00291 UInt_t ivar,jvar;
00292 vector<Double_t> val( GetNvar() );
00293 for (ivar=0; ivar<GetNvar(); ivar++) val[ivar] = ev->GetValue( ivar );
00294
00295 Double_t chi2 = 0;
00296 for (ivar=0; ivar<GetNvar(); ivar++) {
00297 for (jvar=0; jvar<GetNvar(); jvar++) {
00298 if (type == Types::kSignal)
00299 chi2 += ( (val[ivar] - (*fVecMeanS)(ivar))*(val[jvar] - (*fVecMeanS)(jvar))
00300 * (*fInvHMatrixS)(ivar,jvar) );
00301 else
00302 chi2 += ( (val[ivar] - (*fVecMeanB)(ivar))*(val[jvar] - (*fVecMeanB)(jvar))
00303 * (*fInvHMatrixB)(ivar,jvar) );
00304 }
00305 }
00306
00307
00308 if (chi2 < 0) Log() << kFATAL << "<GetChi2> negative chi2: " << chi2 << Endl;
00309
00310 return chi2;
00311 }
00312
00313
////////////////////////////////////////////////////////////////////////////////
/// Write the trained mean vectors and inverse H-matrices to the XML weight
/// file. NOTE: the write order is a serialization contract and must match
/// the read order in ReadWeightsFromXML().
void TMVA::MethodHMatrix::AddWeightsXMLTo( void* parent ) const {
   void* wght = gTools().AddChild(parent, "Weights");
   gTools().WriteTVectorDToXML(wght,"VecMeanS",fVecMeanS);
   gTools().WriteTVectorDToXML(wght,"VecMeanB", fVecMeanB);
   gTools().WriteTMatrixDToXML(wght,"InvHMatS",fInvHMatrixS);
   gTools().WriteTMatrixDToXML(wght,"InvHMatB",fInvHMatrixB);

}
00322
////////////////////////////////////////////////////////////////////////////////
/// Read the mean vectors and inverse H-matrices back from the XML weight
/// file, walking the child nodes in the same order they were written by
/// AddWeightsXMLTo().
void TMVA::MethodHMatrix::ReadWeightsFromXML( void* wghtnode ){
   void* descnode = gTools().GetChild(wghtnode);
   gTools().ReadTVectorDFromXML(descnode,"VecMeanS",fVecMeanS);
   descnode = gTools().GetNextChild(descnode);
   gTools().ReadTVectorDFromXML(descnode,"VecMeanB", fVecMeanB);
   descnode = gTools().GetNextChild(descnode);
   gTools().ReadTMatrixDFromXML(descnode,"InvHMatS",fInvHMatrixS);
   descnode = gTools().GetNextChild(descnode);
   gTools().ReadTMatrixDFromXML(descnode,"InvHMatB",fInvHMatrixB);
}
00333
00334
////////////////////////////////////////////////////////////////////////////////
/// Read the weights from a legacy text-stream weight file. The token order
/// is fixed: a header token, then per-variable signal/background means,
/// then the full signal matrix, then the full background matrix.
void TMVA::MethodHMatrix::ReadWeightsFromStream( istream& istr )
{
   UInt_t ivar,jvar;
   TString var, dummy;
   istr >> dummy;   // skip header token

   // mean vectors: one (signal, background) pair per variable
   for (ivar=0; ivar<GetNvar(); ivar++)
      istr >> (*fVecMeanS)(ivar) >> (*fVecMeanB)(ivar);

   // inverse H-matrix for signal, row-major
   for (ivar=0; ivar<GetNvar(); ivar++)
      for (jvar=0; jvar<GetNvar(); jvar++)
         istr >> (*fInvHMatrixS)(ivar,jvar);

   // inverse H-matrix for background, row-major
   for (ivar=0; ivar<GetNvar(); ivar++)
      for (jvar=0; jvar<GetNvar(); jvar++)
         istr >> (*fInvHMatrixB)(ivar,jvar);
}
00359
00360
////////////////////////////////////////////////////////////////////////////////
/// Emit standalone C++ response code for this trained classifier into fout.
/// The generated class stores the trained means and inverse H-matrices as
/// plain arrays and reimplements GetChi2/GetMvaValue__ without any TMVA
/// dependence. Every emitted string is part of the generated program text,
/// so the literals below must not be altered casually.
void TMVA::MethodHMatrix::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   // member declarations of the generated class (closes the class body)
   fout << "   // arrays of input evt vs. variable " << endl;
   fout << "   double fInvHMatrixS[" << GetNvar() << "][" << GetNvar() << "]; // inverse H-matrix (signal)" << endl;
   fout << "   double fInvHMatrixB[" << GetNvar() << "][" << GetNvar() << "]; // inverse H-matrix (background)" << endl;
   fout << "   double fVecMeanS[" << GetNvar() << "];    // vector of mean values (signal)" << endl;
   fout << "   double fVecMeanB[" << GetNvar() << "];    // vector of mean values (background)" << endl;
   fout << "   " << endl;
   fout << "   double GetChi2( const std::vector<double>& inputValues, int type ) const;" << endl;
   fout << "};" << endl;
   fout << "   " << endl;
   // Initialize(): bake the trained means and matrices in as literals
   fout << "void " << className << "::Initialize() " << endl;
   fout << "{" << endl;
   fout << "   // init vectors with mean values" << endl;
   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      fout << "   fVecMeanS[" << ivar << "] = " << (*fVecMeanS)(ivar) << ";" << endl;
      fout << "   fVecMeanB[" << ivar << "] = " << (*fVecMeanB)(ivar) << ";" << endl;
   }
   fout << "   " << endl;
   fout << "   // init H-matrices" << endl;
   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      for (UInt_t jvar=0; jvar<GetNvar(); jvar++) {
         fout << "   fInvHMatrixS[" << ivar << "][" << jvar << "] = "
              << (*fInvHMatrixS)(ivar,jvar) << ";" << endl;
         fout << "   fInvHMatrixB[" << ivar << "][" << jvar << "] = "
              << (*fInvHMatrixB)(ivar,jvar) << ";" << endl;
      }
   }
   fout << "}" << endl;
   fout << "   " << endl;
   // GetMvaValue__(): mirrors GetMvaValue() above, (b-s)/(s+b)
   fout << "inline double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const" << endl;
   fout << "{" << endl;
   fout << "   // returns the H-matrix signal estimator" << endl;
   fout << "   double s = GetChi2( inputValues, " << Types::kSignal << " );" << endl;
   fout << "   double b = GetChi2( inputValues, " << Types::kBackground << " );" << endl;
   fout << "   " << endl;
   fout << "   if (s+b <= 0) std::cout << \"Problem in class " << className << "::GetMvaValue__: s+b = \"" << endl;
   fout << "                           << s+b << \" <= 0 \"  << std::endl;" << endl;
   fout << "   " << endl;
   fout << "   return (b - s)/(s + b);" << endl;
   fout << "}" << endl;
   fout << "   " << endl;
   // GetChi2(): mirrors the GetChi2 methods above; 'type' selects the class
   fout << "inline double " << className << "::GetChi2( const std::vector<double>& inputValues, int type ) const" << endl;
   fout << "{" << endl;
   fout << "   // compute chi2-estimator for event according to type (signal/background)" << endl;
   fout << "   " << endl;
   fout << "   size_t ivar,jvar;" << endl;
   fout << "   double chi2 = 0;" << endl;
   fout << "   for (ivar=0; ivar<GetNvar(); ivar++) {" << endl;
   fout << "      for (jvar=0; jvar<GetNvar(); jvar++) {" << endl;
   fout << "         if (type == " << Types::kSignal << ") " << endl;
   fout << "            chi2 += ( (inputValues[ivar] - fVecMeanS[ivar])*(inputValues[jvar] - fVecMeanS[jvar])" << endl;
   fout << "                      * fInvHMatrixS[ivar][jvar] );" << endl;
   fout << "         else" << endl;
   fout << "            chi2 += ( (inputValues[ivar] - fVecMeanB[ivar])*(inputValues[jvar] - fVecMeanB[jvar])" << endl;
   fout << "                      * fInvHMatrixB[ivar][jvar] );" << endl;
   fout << "      }" << endl;
   fout << "   }  // loop over variables   " << endl;
   fout << "   " << endl;
   fout << "   // sanity check" << endl;
   fout << "   if (chi2 < 0) std::cout << \"Problem in class " << className << "::GetChi2: chi2 = \"" << endl;
   fout << "                           << chi2 << \" < 0 \"  << std::endl;" << endl;
   fout << "   " << endl;
   fout << "   return chi2;" << endl;
   fout << "}" << endl;
   fout << "   " << endl;
   fout << "// Clean up" << endl;
   fout << "inline void " << className << "::Clear() " << endl;
   fout << "{" << endl;
   fout << "   // nothing to clear" << endl;
   fout << "}" << endl;
}
00434
00435
00436 void TMVA::MethodHMatrix::GetHelpMessage() const
00437 {
00438
00439
00440
00441
00442 Log() << Endl;
00443 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
00444 Log() << Endl;
00445 Log() << "The H-Matrix classifier discriminates one class (signal) of a feature" << Endl;
00446 Log() << "vector from another (background). The correlated elements of the" << Endl;
00447 Log() << "vector are assumed to be Gaussian distributed, and the inverse of" << Endl;
00448 Log() << "the covariance matrix is the H-Matrix. A multivariate chi-squared" << Endl;
00449 Log() << "estimator is built that exploits differences in the mean values of" << Endl;
00450 Log() << "the vector elements between the two classes for the purpose of" << Endl;
00451 Log() << "discrimination." << Endl;
00452 Log() << Endl;
00453 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
00454 Log() << Endl;
00455 Log() << "The TMVA implementation of the H-Matrix classifier has been shown" << Endl;
00456 Log() << "to underperform in comparison with the corresponding Fisher discriminant," << Endl;
00457 Log() << "when using similar assumptions and complexity. Its use is therefore" << Endl;
00458 Log() << "depreciated. Only in cases where the background model is strongly" << Endl;
00459 Log() << "non-Gaussian, H-Matrix may perform better than Fisher. In such" << Endl;
00460 Log() << "occurrences the user is advised to employ non-linear classifiers. " << Endl;
00461 Log() << Endl;
00462 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
00463 Log() << Endl;
00464 Log() << "None" << Endl;
00465 }