00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065
00066
00067
00068
00069
00070
00071
00072
00073
00074
00075
00076
00077
00078
00079
00080
00081
00082
00083
00084
00085
00086
00087
00088
00089
00090
00091
00092
00093
00094
00095
00096
00097
00098
00099
00100
00101
00102
00103
00104 #include <iomanip>
00105 #include <cassert>
00106
00107 #include "TMath.h"
00108 #include "Riostream.h"
00109
00110 #include "TMVA/VariableTransformBase.h"
00111 #include "TMVA/MethodFisher.h"
00112 #include "TMVA/Tools.h"
00113 #include "TMatrix.h"
00114 #include "TMVA/Ranking.h"
00115 #include "TMVA/Types.h"
00116 #include "TMVA/ClassifierFactory.h"
00117
00118 REGISTER_METHOD(Fisher)
00119
00120 ClassImp(TMVA::MethodFisher);
00121
00122
////////////////////////////////////////////////////////////////////////////////
/// standard constructor for the "Fisher" method: registers the method with
/// the given job name, title, dataset and option string; all matrix and
/// coefficient pointers start null and are allocated later in Init()
TMVA::MethodFisher::MethodFisher( const TString& jobName,
                                  const TString& methodTitle,
                                  DataSetInfo& dsi,
                                  const TString& theOption,
                                  TDirectory* theTargetDir ) :
   MethodBase( jobName, Types::kFisher, methodTitle, dsi, theOption, theTargetDir ),
   fMeanMatx     ( 0 ),
   fTheMethod    ( "Fisher" ),          // default discrimination method
   fFisherMethod ( kFisher ),
   fBetw         ( 0 ),
   fWith         ( 0 ),
   fCov          ( 0 ),
   fSumOfWeightsS( 0 ),
   fSumOfWeightsB( 0 ),
   fDiscrimPow   ( 0 ),
   fFisherCoeff  ( 0 ),
   fF0           ( 0 )
{
   // all members are set in the initializer list; nothing else to do here
}
00143
00144
////////////////////////////////////////////////////////////////////////////////
/// constructor from weight file: used when reading a previously trained
/// classifier back for application; same zero-initialisation as the
/// standard constructor
TMVA::MethodFisher::MethodFisher( DataSetInfo& dsi,
                                  const TString& theWeightFile,
                                  TDirectory* theTargetDir ) :
   MethodBase( Types::kFisher, dsi, theWeightFile, theTargetDir ),
   fMeanMatx     ( 0 ),
   fTheMethod    ( "Fisher" ),
   fFisherMethod ( kFisher ),
   fBetw         ( 0 ),
   fWith         ( 0 ),
   fCov          ( 0 ),
   fSumOfWeightsS( 0 ),
   fSumOfWeightsB( 0 ),
   fDiscrimPow   ( 0 ),
   fFisherCoeff  ( 0 ),
   fF0           ( 0 )
{
   // all members are set in the initializer list; nothing else to do here
}
00163
00164
////////////////////////////////////////////////////////////////////////////////
/// default initialisation called by all constructors
void TMVA::MethodFisher::Init( void )
{
   // allocate the Fisher coefficient vector: one coefficient per input variable
   fFisherCoeff = new std::vector<Double_t>( GetNvar() );

   // reference cut value used to decide signal-likeness of the MVA output
   SetSignalReferenceCut( 0.0 );

   // allocate the mean matrix and the covariance matrices
   InitMatrices();
}
00178
00179
////////////////////////////////////////////////////////////////////////////////
/// MethodFisher options:
///
///    Method  <string>   discrimination method; allowed values:
///                       "Fisher" (default) or "Mahalanobis"
void TMVA::MethodFisher::DeclareOptions()
{
   DeclareOptionRef( fTheMethod = "Fisher", "Method", "Discrimination method" );
   AddPreDefVal(TString("Fisher"));
   AddPreDefVal(TString("Mahalanobis"));
}
00191
00192
////////////////////////////////////////////////////////////////////////////////
/// process user options: translate the "Method" string into the
/// corresponding enum value
void TMVA::MethodFisher::ProcessOptions()
{
   // any value other than "Fisher" selects the Mahalanobis variant
   // (the option system restricts input to the two predefined values)
   if (fTheMethod == "Fisher" ) fFisherMethod = kFisher;
   else fFisherMethod = kMahalanobis;

   // (re-)initialise the matrices
   InitMatrices();
}
00202
00203
00204 TMVA::MethodFisher::~MethodFisher( void )
00205 {
00206
00207 if (fBetw ) { delete fBetw; fBetw = 0; }
00208 if (fWith ) { delete fWith; fWith = 0; }
00209 if (fCov ) { delete fCov; fCov = 0; }
00210 if (fDiscrimPow ) { delete fDiscrimPow; fDiscrimPow = 0; }
00211 if (fFisherCoeff) { delete fFisherCoeff; fFisherCoeff = 0; }
00212 }
00213
00214
00215 Bool_t TMVA::MethodFisher::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
00216 {
00217
00218 if (type == Types::kClassification && numberClasses == 2) return kTRUE;
00219 return kFALSE;
00220 }
00221
00222
////////////////////////////////////////////////////////////////////////////////
/// computation of the Fisher coefficients: orchestrates the full training
/// sequence -- class means, covariance matrices, coefficients and the
/// per-variable discriminating power
void TMVA::MethodFisher::Train( void )
{
   // get the mean value of each variable for signal, background and both
   GetMean();

   // get the matrix of covariance 'within class'
   GetCov_WithinClass();

   // get the matrix of covariance 'between class'
   GetCov_BetweenClass();

   // get the full covariance matrix (sum of the two above)
   GetCov_Full();

   // compute the Fisher coefficients and the offset fF0
   GetFisherCoeff();

   // compute the discriminating power of each variable
   GetDiscrimPower();

   // nicely formatted output of the coefficients
   PrintCoefficients();
}
00250
00251
00252 Double_t TMVA::MethodFisher::GetMvaValue( Double_t* err, Double_t* errUpper )
00253 {
00254
00255 const Event * ev = GetEvent();
00256 Double_t result = fF0;
00257 for (UInt_t ivar=0; ivar<GetNvar(); ivar++)
00258 result += (*fFisherCoeff)[ivar]*ev->GetValue(ivar);
00259
00260
00261 NoErrorCalc(err, errUpper);
00262
00263 return result;
00264
00265 }
00266
00267
00268 void TMVA::MethodFisher::InitMatrices( void )
00269 {
00270
00271
00272
00273 fMeanMatx = new TMatrixD( GetNvar(), 3 );
00274
00275
00276 fBetw = new TMatrixD( GetNvar(), GetNvar() );
00277 fWith = new TMatrixD( GetNvar(), GetNvar() );
00278 fCov = new TMatrixD( GetNvar(), GetNvar() );
00279
00280
00281 fDiscrimPow = new std::vector<Double_t>( GetNvar() );
00282 }
00283
00284
00285 void TMVA::MethodFisher::GetMean( void )
00286 {
00287
00288
00289
00290 fSumOfWeightsS = 0;
00291 fSumOfWeightsB = 0;
00292
00293 const UInt_t nvar = DataInfo().GetNVariables();
00294
00295
00296 Double_t* sumS = new Double_t[nvar];
00297 Double_t* sumB = new Double_t[nvar];
00298 for (UInt_t ivar=0; ivar<nvar; ivar++) { sumS[ivar] = sumB[ivar] = 0; }
00299
00300
00301 for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
00302
00303
00304 const Event * ev = GetEvent(ievt);
00305
00306
00307 Double_t weight = GetTWeight(ev);
00308 if (DataInfo().IsSignal(ev)) fSumOfWeightsS += weight;
00309 else fSumOfWeightsB += weight;
00310
00311 Double_t* sum = DataInfo().IsSignal(ev) ? sumS : sumB;
00312
00313 for (UInt_t ivar=0; ivar<nvar; ivar++) sum[ivar] += ev->GetValue( ivar )*weight;
00314 }
00315
00316 for (UInt_t ivar=0; ivar<nvar; ivar++) {
00317 (*fMeanMatx)( ivar, 2 ) = sumS[ivar];
00318 (*fMeanMatx)( ivar, 0 ) = sumS[ivar]/fSumOfWeightsS;
00319
00320 (*fMeanMatx)( ivar, 2 ) += sumB[ivar];
00321 (*fMeanMatx)( ivar, 1 ) = sumB[ivar]/fSumOfWeightsB;
00322
00323
00324 (*fMeanMatx)( ivar, 2 ) /= (fSumOfWeightsS + fSumOfWeightsB);
00325 }
00326 delete [] sumS;
00327 delete [] sumB;
00328 }
00329
00330
00331 void TMVA::MethodFisher::GetCov_WithinClass( void )
00332 {
00333
00334
00335
00336
00337 assert( fSumOfWeightsS > 0 && fSumOfWeightsB > 0 );
00338
00339
00340
00341
00342 const Int_t nvar = GetNvar();
00343 const Int_t nvar2 = nvar*nvar;
00344 Double_t *sumSig = new Double_t[nvar2];
00345 Double_t *sumBgd = new Double_t[nvar2];
00346 Double_t *xval = new Double_t[nvar];
00347 memset(sumSig,0,nvar2*sizeof(Double_t));
00348 memset(sumBgd,0,nvar2*sizeof(Double_t));
00349
00350
00351 for (Int_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
00352
00353
00354 const Event* ev = GetEvent(ievt);
00355
00356 Double_t weight = GetTWeight(ev);
00357
00358 for (Int_t x=0; x<nvar; x++) xval[x] = ev->GetValue( x );
00359 Int_t k=0;
00360 for (Int_t x=0; x<nvar; x++) {
00361 for (Int_t y=0; y<nvar; y++) {
00362 Double_t v = ( (xval[x] - (*fMeanMatx)(x, 0))*(xval[y] - (*fMeanMatx)(y, 0)) )*weight;
00363 if (DataInfo().IsSignal(ev)) sumSig[k] += v;
00364 else sumBgd[k] += v;
00365 k++;
00366 }
00367 }
00368 }
00369 Int_t k=0;
00370 for (Int_t x=0; x<nvar; x++) {
00371 for (Int_t y=0; y<nvar; y++) {
00372 (*fWith)(x, y) = (sumSig[k] + sumBgd[k])/(fSumOfWeightsS + fSumOfWeightsB);
00373 k++;
00374 }
00375 }
00376
00377 delete [] sumSig;
00378 delete [] sumBgd;
00379 delete [] xval;
00380 }
00381
00382
00383 void TMVA::MethodFisher::GetCov_BetweenClass( void )
00384 {
00385
00386
00387
00388
00389
00390 assert( fSumOfWeightsS > 0 && fSumOfWeightsB > 0);
00391
00392 Double_t prodSig, prodBgd;
00393
00394 for (UInt_t x=0; x<GetNvar(); x++) {
00395 for (UInt_t y=0; y<GetNvar(); y++) {
00396
00397 prodSig = ( ((*fMeanMatx)(x, 0) - (*fMeanMatx)(x, 2))*
00398 ((*fMeanMatx)(y, 0) - (*fMeanMatx)(y, 2)) );
00399 prodBgd = ( ((*fMeanMatx)(x, 1) - (*fMeanMatx)(x, 2))*
00400 ((*fMeanMatx)(y, 1) - (*fMeanMatx)(y, 2)) );
00401
00402 (*fBetw)(x, y) = (fSumOfWeightsS*prodSig + fSumOfWeightsB*prodBgd) / (fSumOfWeightsS + fSumOfWeightsB);
00403 }
00404 }
00405 }
00406
00407
00408 void TMVA::MethodFisher::GetCov_Full( void )
00409 {
00410
00411 for (UInt_t x=0; x<GetNvar(); x++)
00412 for (UInt_t y=0; y<GetNvar(); y++)
00413 (*fCov)(x, y) = (*fWith)(x, y) + (*fBetw)(x, y);
00414 }
00415
00416
00417 void TMVA::MethodFisher::GetFisherCoeff( void )
00418 {
00419
00420
00421
00422
00423
00424
00425
00426
00427
00428
00429 assert( fSumOfWeightsS > 0 && fSumOfWeightsB > 0);
00430
00431
00432 TMatrixD* theMat = 0;
00433 switch (GetFisherMethod()) {
00434 case kFisher:
00435 theMat = fWith;
00436 break;
00437 case kMahalanobis:
00438 theMat = fCov;
00439 break;
00440 default:
00441 Log() << kFATAL << "<GetFisherCoeff> undefined method" << GetFisherMethod() << Endl;
00442 }
00443
00444 TMatrixD invCov( *theMat );
00445 if ( TMath::Abs(invCov.Determinant()) < 10E-24 ) {
00446 Log() << kWARNING << "<GetFisherCoeff> matrix is almost singular with deterninant="
00447 << TMath::Abs(invCov.Determinant())
00448 << " did you use the variables that are linear combinations or highly correlated?"
00449 << Endl;
00450 }
00451 if ( TMath::Abs(invCov.Determinant()) < 10E-120 ) {
00452 Log() << kFATAL << "<GetFisherCoeff> matrix is singular with determinant="
00453 << TMath::Abs(invCov.Determinant())
00454 << " did you use the variables that are linear combinations?"
00455 << Endl;
00456 }
00457
00458 invCov.Invert();
00459
00460
00461 Double_t xfact = TMath::Sqrt( fSumOfWeightsS*fSumOfWeightsB ) / (fSumOfWeightsS + fSumOfWeightsB);
00462
00463
00464 std::vector<Double_t> diffMeans( GetNvar() );
00465 UInt_t ivar, jvar;
00466 for (ivar=0; ivar<GetNvar(); ivar++) {
00467 (*fFisherCoeff)[ivar] = 0;
00468
00469 for (jvar=0; jvar<GetNvar(); jvar++) {
00470 Double_t d = (*fMeanMatx)(jvar, 0) - (*fMeanMatx)(jvar, 1);
00471 (*fFisherCoeff)[ivar] += invCov(ivar, jvar)*d;
00472 }
00473
00474
00475 (*fFisherCoeff)[ivar] *= xfact;
00476 }
00477
00478
00479 fF0 = 0.0;
00480 for (ivar=0; ivar<GetNvar(); ivar++){
00481 fF0 += (*fFisherCoeff)[ivar]*((*fMeanMatx)(ivar, 0) + (*fMeanMatx)(ivar, 1));
00482 }
00483 fF0 /= -2.0;
00484 }
00485
00486
00487 void TMVA::MethodFisher::GetDiscrimPower( void )
00488 {
00489
00490
00491
00492
00493
00494
00495 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
00496 if ((*fCov)(ivar, ivar) != 0)
00497 (*fDiscrimPow)[ivar] = (*fBetw)(ivar, ivar)/(*fCov)(ivar, ivar);
00498 else
00499 (*fDiscrimPow)[ivar] = 0;
00500 }
00501 }
00502
00503
00504 const TMVA::Ranking* TMVA::MethodFisher::CreateRanking()
00505 {
00506
00507
00508
00509 fRanking = new Ranking( GetName(), "Discr. power" );
00510
00511 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
00512 fRanking->AddRank( Rank( GetInputLabel(ivar), (*fDiscrimPow)[ivar] ) );
00513 }
00514
00515 return fRanking;
00516 }
00517
00518
00519 void TMVA::MethodFisher::PrintCoefficients( void )
00520 {
00521
00522
00523 Log() << kINFO << "Results for Fisher coefficients:" << Endl;
00524
00525 if (GetTransformationHandler().GetTransformationList().GetSize() != 0) {
00526 Log() << kINFO << "NOTE: The coefficients must be applied to TRANFORMED variables" << Endl;
00527 Log() << kINFO << " List of the transformation: " << Endl;
00528 TListIter trIt(&GetTransformationHandler().GetTransformationList());
00529 while (VariableTransformBase *trf = (VariableTransformBase*) trIt()) {
00530 Log() << kINFO << " -- " << trf->GetName() << Endl;
00531 }
00532 }
00533 std::vector<TString> vars;
00534 std::vector<Double_t> coeffs;
00535 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
00536 vars .push_back( GetInputLabel(ivar) );
00537 coeffs.push_back( (*fFisherCoeff)[ivar] );
00538 }
00539 vars .push_back( "(offset)" );
00540 coeffs.push_back( fF0 );
00541 TMVA::gTools().FormattedOutput( coeffs, vars, "Variable" , "Coefficient", Log() );
00542
00543 if (IsNormalised()) {
00544 Log() << kINFO << "NOTE: You have chosen to use the \"Normalise\" booking option. Hence, the" << Endl;
00545 Log() << kINFO << " coefficients must be applied to NORMALISED (') variables as follows:" << Endl;
00546 Int_t maxL = 0;
00547 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) if (GetInputLabel(ivar).Length() > maxL) maxL = GetInputLabel(ivar).Length();
00548
00549
00550 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
00551 Log() << kINFO
00552 << setw(maxL+9) << TString("[") + GetInputLabel(ivar) + "]' = 2*("
00553 << setw(maxL+2) << TString("[") + GetInputLabel(ivar) + "]"
00554 << setw(3) << (GetXmin(ivar) > 0 ? " - " : " + ")
00555 << setw(6) << TMath::Abs(GetXmin(ivar)) << setw(3) << ")/"
00556 << setw(6) << (GetXmax(ivar) - GetXmin(ivar) )
00557 << setw(3) << " - 1"
00558 << Endl;
00559 }
00560 Log() << kINFO << "The TMVA Reader will properly account for this normalisation, but if the" << Endl;
00561 Log() << kINFO << "Fisher classifier is applied outside the Reader, the transformation must be" << Endl;
00562 Log() << kINFO << "implemented -- or the \"Normalise\" option is removed and Fisher retrained." << Endl;
00563 Log() << kINFO << Endl;
00564 }
00565 }
00566
00567
////////////////////////////////////////////////////////////////////////////////
/// read Fisher coefficients from a plain-text weight stream: first the
/// offset fF0, then one coefficient per input variable
void TMVA::MethodFisher::ReadWeightsFromStream( istream& istr )
{
   istr >> fF0;
   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) istr >> (*fFisherCoeff)[ivar];
}
00574
00575
////////////////////////////////////////////////////////////////////////////////
/// create XML description of the Fisher classifier: a "Weights" node with
/// attribute NCoeff = GetNvar()+1 and one "Coefficient" child per value;
/// Index 0 carries the offset fF0, indices 1..GetNvar() the per-variable
/// coefficients
void TMVA::MethodFisher::AddWeightsXMLTo( void* parent ) const
{
   void* wght = gTools().AddChild(parent, "Weights");
   gTools().AddAttr( wght, "NCoeff", GetNvar()+1 );
   // the offset is stored as the coefficient with index 0
   void* coeffxml = gTools().AddChild(wght, "Coefficient");
   gTools().AddAttr( coeffxml, "Index", 0 );
   gTools().AddAttr( coeffxml, "Value", fF0 );
   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      coeffxml = gTools().AddChild( wght, "Coefficient" );
      gTools().AddAttr( coeffxml, "Index", ivar+1 );
      gTools().AddAttr( coeffxml, "Value", (*fFisherCoeff)[ivar] );
   }
}
00591
00592
////////////////////////////////////////////////////////////////////////////////
/// read Fisher coefficients from an XML weight node (inverse of
/// AddWeightsXMLTo): NCoeff counts the coefficients including the offset;
/// each "Coefficient" child carries an Index (0 = offset fF0, i > 0 =
/// variable i-1) and a Value
void TMVA::MethodFisher::ReadWeightsFromXML( void* wghtnode )
{
   UInt_t ncoeff, coeffidx;
   gTools().ReadAttr( wghtnode, "NCoeff", ncoeff );
   // one slot per variable; the offset is stored separately in fF0
   fFisherCoeff->resize(ncoeff-1);

   // loop over the child nodes and dispatch on the coefficient index
   void* ch = gTools().GetChild(wghtnode);
   Double_t coeff;
   while (ch) {
      gTools().ReadAttr( ch, "Index", coeffidx );
      gTools().ReadAttr( ch, "Value", coeff );
      if (coeffidx==0) fF0 = coeff;
      else (*fFisherCoeff)[coeffidx-1] = coeff;
      ch = gTools().GetNextChild(ch);
   }
}
00610
00611
////////////////////////////////////////////////////////////////////////////////
/// write the Fisher-specific part of a standalone C++ response class to
/// 'fout': member declarations, an Initialize() that hard-codes the trained
/// offset and coefficients, a GetMvaValue__() computing the linear
/// response, and a Clear() method
void TMVA::MethodFisher::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   // remember the stream precision; it is restored at the end
   Int_t dp = fout.precision();
   fout << " double fFisher0;" << endl;
   fout << " std::vector<double> fFisherCoefficients;" << endl;
   fout << "};" << endl;
   fout << "" << endl;
   fout << "inline void " << className << "::Initialize() " << endl;
   fout << "{" << endl;
   // 12 significant digits so the generated class reproduces the response
   fout << " fFisher0 = " << std::setprecision(12) << fF0 << ";" << endl;
   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      fout << " fFisherCoefficients.push_back( " << std::setprecision(12) << (*fFisherCoeff)[ivar] << " );" << endl;
   }
   fout << endl;
   fout << " // sanity check" << endl;
   fout << " if (fFisherCoefficients.size() != fNvars) {" << endl;
   fout << " std::cout << \"Problem in class \\\"\" << fClassName << \"\\\"::Initialize: mismatch in number of input values\"" << endl;
   fout << " << fFisherCoefficients.size() << \" != \" << fNvars << std::endl;" << endl;
   fout << " fStatusIsClean = false;" << endl;
   fout << " } " << endl;
   fout << "}" << endl;
   fout << endl;
   fout << "inline double " << className << "::GetMvaValue__( const std::vector<double>& inputValues ) const" << endl;
   fout << "{" << endl;
   fout << " double retval = fFisher0;" << endl;
   fout << " for (size_t ivar = 0; ivar < fNvars; ivar++) {" << endl;
   fout << " retval += fFisherCoefficients[ivar]*inputValues[ivar];" << endl;
   fout << " }" << endl;
   fout << endl;
   fout << " return retval;" << endl;
   fout << "}" << endl;
   fout << endl;
   fout << "// Clean up" << endl;
   fout << "inline void " << className << "::Clear() " << endl;
   fout << "{" << endl;
   fout << " // clear coefficients" << endl;
   fout << " fFisherCoefficients.clear(); " << endl;
   fout << "}" << endl;
   // restore the caller's stream precision
   fout << std::setprecision(dp);
}
00653
00654
////////////////////////////////////////////////////////////////////////////////
/// get help message text: short description of the Fisher discriminant,
/// notes on performance optimisation and on the (absent) tuning options;
/// typical length of text lines is 80 characters
void TMVA::MethodFisher::GetHelpMessage() const
{
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "Fisher discriminants select events by distinguishing the mean " << Endl;
   Log() << "values of the signal and background distributions in a trans- " << Endl;
   Log() << "formed variable space where linear correlations are removed." << Endl;
   Log() << Endl;
   Log() << " (More precisely: the \"linear discriminator\" determines" << Endl;
   Log() << " an axis in the (correlated) hyperspace of the input " << Endl;
   Log() << " variables such that, when projecting the output classes " << Endl;
   Log() << " (signal and background) upon this axis, they are pushed " << Endl;
   Log() << " as far as possible away from each other, while events" << Endl;
   Log() << " of a same class are confined in a close vicinity. The " << Endl;
   Log() << " linearity property of this classifier is reflected in the " << Endl;
   Log() << " metric with which \"far apart\" and \"close vicinity\" are " << Endl;
   Log() << " determined: the covariance matrix of the discriminating" << Endl;
   Log() << " variable space.)" << Endl;
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "Optimal performance for Fisher discriminants is obtained for " << Endl;
   Log() << "linearly correlated Gaussian-distributed variables. Any deviation" << Endl;
   Log() << "from this ideal reduces the achievable separation power. In " << Endl;
   Log() << "particular, no discrimination at all is achieved for a variable" << Endl;
   Log() << "that has the same sample mean for signal and background, even if " << Endl;
   Log() << "the shapes of the distributions are very different. Thus, Fisher " << Endl;
   Log() << "discriminants often benefit from suitable transformations of the " << Endl;
   Log() << "input variables. For example, if a variable x in [-1,1] has a " << Endl;
   Log() << "a parabolic signal distributions, and a uniform background" << Endl;
   Log() << "distributions, their mean value is zero in both cases, leading " << Endl;
   Log() << "to no separation. The simple transformation x -> |x| renders this " << Endl;
   Log() << "variable powerful for the use in a Fisher discriminant." << Endl;
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "<None>" << Endl;
}