00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048 #include "TMVA/ClassifierFactory.h"
00049 #include "TMVA/MethodCommittee.h"
00050 #include "TMVA/Tools.h"
00051 #include "TMVA/Timer.h"
00052 #include "Riostream.h"
00053 #include "TMath.h"
00054 #include "TRandom3.h"
00055 #include <algorithm>
00056 #include "TObjString.h"
00057 #include "TDirectory.h"
00058 #include "TMVA/Ranking.h"
00059 #include "TMVA/IMethod.h"
00060
// Unqualified `vector` is used by several method implementations below.
00061 using std::vector;
00062
// Registers the "Committee" method with the TMVA classifier factory (macro from
// TMVA/ClassifierFactory.h); presumably this is what lets it be created by name,
// the same way Train() creates its members — confirm against the macro definition.
00063 REGISTER_METHOD(Committee)
00064
// ROOT class-implementation macro for TMVA::MethodCommittee.
00065 ClassImp(TMVA::MethodCommittee)
00066
00067
// Standard constructor: forwards the job name, method title, dataset info,
// option string and target directory to MethodBase (method type
// Types::kCommittee) and sets the member defaults listed below
// (100 members, AdaBoost, unweighted/non-binary combination, zeroed monitors).
// NOTE(review): fMemberType starts as Types::kMaxMethod and fMemberOption is not
// initialised here; nothing visible in this file assigns either before Train()
// uses them to create members — confirm they are set elsewhere.
00068 TMVA::MethodCommittee::MethodCommittee( const TString& jobName,
00069 const TString& methodTitle,
00070 DataSetInfo& dsi,
00071 const TString& theOption,
00072 TDirectory* theTargetDir ) :
00073 TMVA::MethodBase( jobName, Types::kCommittee, methodTitle, dsi, theOption, theTargetDir ),
00074 fNMembers(100),
00075 fBoostType("AdaBoost"),
00076 fMemberType(Types::kMaxMethod),
00077 fUseMemberDecision(kFALSE),
00078 fUseWeightedMembers(kFALSE),
// fITree / fBoostFactor / fErrorFraction back the branches of the monitoring ntuple
00079 fITree(0),
00080 fBoostFactor(0),
00081 fErrorFraction(0),
// NOTE(review): fNnodes is not read anywhere in this file
00082 fNnodes(0)
00083 {
00084
00085 }
00086
00087
// Constructor taking a dataset info and a weight-file name (theWeightFile),
// both forwarded to MethodBase together with the target directory; member
// defaults are identical to the standard constructor.
// NOTE(review): fMemberType stays Types::kMaxMethod here as well, yet
// ReadWeightsFromStream() creates members from it — confirm it is set elsewhere.
00088 TMVA::MethodCommittee::MethodCommittee( DataSetInfo& theData,
00089 const TString& theWeightFile,
00090 TDirectory* theTargetDir ) :
00091 TMVA::MethodBase( Types::kCommittee, theData, theWeightFile, theTargetDir ),
00092 fNMembers(100),
00093 fBoostType("AdaBoost"),
00094 fMemberType(Types::kMaxMethod),
00095 fUseMemberDecision(kFALSE),
00096 fUseWeightedMembers(kFALSE),
00097 fITree(0),
00098 fBoostFactor(0),
00099 fErrorFraction(0),
00100 fNnodes(0)
00101 {
00102
00103
00104
00105
00106 }
00107
00108
00109 Bool_t TMVA::MethodCommittee::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets )
00110 {
00111
00112 if( type == Types::kClassification && numberClasses == 2 ) return kTRUE;
00113 if( type == Types::kRegression && numberTargets == 1 ) return kTRUE;
00114 return kFALSE;
00115 }
00116
00117
// Declare the user-configurable options of the committee:
//   NMembers           -> fNMembers: number of members trained in Train()
//   UseMemberDecision  -> fUseMemberDecision (default kFALSE): GetMvaValue() combines
//                         the members' IsSignalLike() decisions as +1/-1 instead of
//                         their MVA values
//   UseWeightedMembers -> fUseWeightedMembers (default kTRUE): weight each member by
//                         its boost weight instead of taking a plain average
//   BoostType          -> fBoostType, restricted to "AdaBoost" or "Bagging"
// NOTE(review): the kTRUE default for UseWeightedMembers differs from the kFALSE the
// constructors initialise; the member method type/options are not declared as options.
00118 void TMVA::MethodCommittee::DeclareOptions()
00119 {
00120
00121
00122
00123
00124
00125
00126
00127
00128
00129
00130 DeclareOptionRef(fNMembers, "NMembers", "number of members in the committee");
00131 DeclareOptionRef(fUseMemberDecision=kFALSE, "UseMemberDecision", "use binary information from IsSignal");
00132 DeclareOptionRef(fUseWeightedMembers=kTRUE, "UseWeightedMembers", "use weighted trees or simple average in classification from the forest");
00133
// the allowed values below attach to the most recently declared option (BoostType)
00134 DeclareOptionRef(fBoostType, "BoostType", "boosting type");
00135 AddPreDefVal(TString("AdaBoost"));
00136 AddPreDefVal(TString("Bagging"));
00137 }
00138
00139
// Create the monitoring objects that training fills:
//   fBoostFactorHist - AdaBoost boost factors, 100 bins over [1,100]
//                      (the value 1000 used for a perfect member lands in overflow)
//   fErrFractHist    - error fraction (50 bins over [0,0.5]) vs member index
//                      (fNMembers bins), so it must run after NMembers is parsed
//   fMonitorNtuple   - one entry per member with branches iTree, boostFactor and
//                      errorFraction, bound to fITree, fBoostFactor, fErrorFraction
// NOTE(review): the objects are created with new on every call and are not deleted
// in this file; ownership presumably follows ROOT's current-directory rules — confirm.
00140 void TMVA::MethodCommittee::ProcessOptions()
00141 {
00142
00143
00144
00145 fBoostFactorHist = new TH1F("fBoostFactor","Ada Boost weights",100,1,100);
00146 fErrFractHist = new TH2F("fErrFractHist","error fraction vs tree number",
00147 fNMembers,0,fNMembers,50,0,0.5);
00148 fMonitorNtuple = new TTree("fMonitorNtuple","Committee variables");
00149 fMonitorNtuple->Branch("iTree",&fITree,"iTree/I");
00150 fMonitorNtuple->Branch("boostFactor",&fBoostFactor,"boostFactor/D");
00151 fMonitorNtuple->Branch("errorFraction",&fErrorFraction,"errorFraction/D");
00152 }
00153
00154
00155 void TMVA::MethodCommittee::Init( void )
00156 {
00157
00158
00159 fNMembers = 100;
00160 fBoostType = "AdaBoost";
00161
00162 fCommittee.clear();
00163 fBoostWeights.clear();
00164 }
00165
00166
00167 TMVA::MethodCommittee::~MethodCommittee( void )
00168 {
00169
00170 for (UInt_t i=0; i<GetCommittee().size(); i++) delete fCommittee[i];
00171 fCommittee.clear();
00172 }
00173
00174
00175 void TMVA::MethodCommittee::WriteStateToFile() const
00176 {
00177
00178
00179
00180 TString fname(GetWeightFileName());
00181 Log() << kINFO << "creating weight file: " << fname << Endl;
00182
00183 std::ofstream* fout = new std::ofstream( fname );
00184 if (!fout->good()) {
00185 Log() << kFATAL << "<WriteStateToFile> "
00186 << "unable to open output weight file: " << fname << endl;
00187 }
00188
00189 WriteStateToStream( *fout );
00190 }
00191
00192
00193
00194 void TMVA::MethodCommittee::Train( void )
00195 {
00196
00197
00198 Log() << kINFO << "will train "<< fNMembers << " committee members ... patience please" << Endl;
00199
00200 Timer timer( fNMembers, GetName() );
00201 for (UInt_t imember=0; imember<fNMembers; imember++){
00202 timer.DrawProgressBar( imember );
00203
00204 IMethod* method = ClassifierFactory::Instance().Create(std::string(Types::Instance().GetMethodName( fMemberType )),
00205 GetJobName(),
00206 GetMethodName(),
00207 DataInfo(),
00208 fMemberOption );
00209
00210
00211
00212
00213 method->Train();
00214
00215 GetBoostWeights().push_back( this->Boost( dynamic_cast<MethodBase*>(method), imember ) );
00216
00217 GetCommittee().push_back( method );
00218
00219 fMonitorNtuple->Fill();
00220 }
00221
00222
00223 Log() << kINFO << "elapsed time: " << timer.GetElapsedTime()
00224 << " " << Endl;
00225 }
00226
00227
00228 Double_t TMVA::MethodCommittee::Boost( TMVA::MethodBase* method, UInt_t imember )
00229 {
00230
00231
00232 if(!method)
00233 return 0;
00234
00235 if (fBoostType=="AdaBoost") return this->AdaBoost( method );
00236 else if (fBoostType=="Bagging") return this->Bagging( imember );
00237 else {
00238 Log() << kINFO << GetOptions() << Endl;
00239 Log() << kFATAL << "<Boost> unknown boost option called" << Endl;
00240 }
00241 return 1.0;
00242 }
00243
00244
00245 Double_t TMVA::MethodCommittee::AdaBoost( TMVA::MethodBase* method )
00246 {
00247
00248
00249
00250
00251
00252
00253
00254
00255
00256
00257 Double_t adaBoostBeta = 1.;
00258
00259
00260 if (Data()->GetNTrainingEvents()) Log() << kFATAL << "<AdaBoost> Data().TrainingTree() is zero pointer" << Endl;
00261
00262 Double_t err=0, sumw=0, sumwfalse=0, count=0;
00263 vector<Char_t> correctSelected;
00264
00265
00266 MethodBase* mbase = (MethodBase*)method;
00267 for (Int_t ievt=0; ievt<Data()->GetNTrainingEvents(); ievt++) {
00268
00269 Event* ev = Data()->GetEvent(ievt);
00270
00271
00272 sumw += ev->GetBoostWeight();
00273
00274
00275 Bool_t isSignalType = mbase->IsSignalLike();
00276
00277
00278 if (isSignalType == DataInfo().IsSignal(ev))
00279 correctSelected.push_back( kTRUE );
00280 else {
00281 sumwfalse += ev->GetBoostWeight();
00282 count += 1;
00283 correctSelected.push_back( kFALSE );
00284 }
00285 }
00286
00287 if (0 == sumw) {
00288 Log() << kFATAL << "<AdaBoost> fatal error sum of event boostweights is zero" << Endl;
00289 }
00290
00291
00292 err = sumwfalse/sumw;
00293
00294 Double_t newSumw=0;
00295 Int_t i=0;
00296 Double_t boostFactor = 1;
00297 if (err>0){
00298 if (adaBoostBeta == 1){
00299 boostFactor = (1-err)/err ;
00300 }
00301 else {
00302 boostFactor = TMath::Power((1-err)/err,adaBoostBeta) ;
00303 }
00304 }
00305 else {
00306 boostFactor = 1000;
00307 }
00308
00309
00310 for (Int_t ievt=0; ievt<Data()->GetNTrainingEvents(); ievt++) {
00311
00312 Event *ev = Data()->GetEvent(ievt);
00313
00314
00315 if (!correctSelected[ievt]) ev->SetBoostWeight( ev->GetBoostWeight() * boostFactor);
00316
00317 newSumw += ev->GetBoostWeight();
00318 i++;
00319 }
00320
00321
00322 for (Int_t ievt=0; ievt<Data()->GetNTrainingEvents(); ievt++) {
00323 Event *ev = Data()->GetEvent(ievt);
00324 ev->SetBoostWeight( ev->GetBoostWeight() * sumw / newSumw );
00325 }
00326
00327 fBoostFactorHist->Fill(boostFactor);
00328 fErrFractHist->Fill(GetCommittee().size(),err);
00329
00330
00331 fBoostFactor = boostFactor;
00332 fErrorFraction = err;
00333
00334
00335 return TMath::Log(boostFactor);
00336 }
00337
00338
00339 Double_t TMVA::MethodCommittee::Bagging( UInt_t imember )
00340 {
00341
00342
00343 Double_t newSumw = 0;
00344 TRandom3* trandom = new TRandom3( imember );
00345
00346
00347 for (Int_t ievt=0; ievt<Data()->GetNTrainingEvents(); ievt++) {
00348 Event* ev = Data()->GetEvent(ievt);
00349
00350
00351 Double_t newWeight = trandom->Rndm();
00352 ev->SetBoostWeight( newWeight );
00353 newSumw += newWeight;
00354 }
00355
00356
00357 for (Int_t ievt=0; ievt<Data()->GetNTrainingEvents(); ievt++) {
00358 Event* ev = Data()->GetEvent(ievt);
00359 ev->SetBoostWeight( ev->GetBoostWeight() * Data()->GetNTrainingEvents() / newSumw );
00360 }
00361
00362 delete trandom;
00363
00364 return 1.0;
00365 }
00366
00367
00368 void TMVA::MethodCommittee::AddWeightsXMLTo( void* ) const {
00369 Log() << kFATAL << "Please implement writing of weights as XML" << Endl;
00370 }
00371
00372
00373 void TMVA::MethodCommittee::ReadWeightsFromStream( istream& istr )
00374 {
00375
00376
00377
00378 std::vector<IMethod*>::iterator member = GetCommittee().begin();
00379 for (; member != GetCommittee().end(); member++) delete *member;
00380
00381 GetCommittee().clear();
00382 GetBoostWeights().clear();
00383
00384 TString dummy;
00385 UInt_t imember;
00386 Double_t boostWeight;
00387
00388 DataSetInfo & dsi = DataInfo();
00389
00390
00391 for (UInt_t i=0; i<fNMembers; i++) {
00392
00393 istr >> dummy >> dummy >> dummy >> imember;
00394 istr >> dummy >> dummy >> boostWeight;
00395
00396 if (imember != i) {
00397 Log() << kFATAL << "<ReadWeightsFromStream> fatal error while reading Weight file \n "
00398 << ": mismatch imember: " << imember << " != i: " << i << Endl;
00399 }
00400
00401
00402 IMethod* method = ClassifierFactory::Instance().Create(std::string(Types::Instance().GetMethodName( fMemberType )), dsi, "" );
00403
00404
00405 MethodBase* m = dynamic_cast<MethodBase*>(method);
00406 if(m)
00407 m->ReadStateFromStream(istr);
00408 GetCommittee().push_back(method);
00409 GetBoostWeights().push_back(boostWeight);
00410 }
00411 }
00412
00413
00414 Double_t TMVA::MethodCommittee::GetMvaValue( Double_t* err, Double_t* errUpper )
00415 {
00416
00417
00418
00419
00420
00421
00422
00423
00424
00425 NoErrorCalc(err, errUpper);
00426
00427 Double_t myMVA = 0;
00428 Double_t norm = 0;
00429 for (UInt_t itree=0; itree<GetCommittee().size(); itree++) {
00430
00431 MethodBase* m = dynamic_cast<MethodBase*>(GetCommittee()[itree]);
00432 if(m==0) continue;
00433
00434 Double_t tmpMVA = ( fUseMemberDecision ? ( m->IsSignalLike() ? 1.0 : -1.0 )
00435 : GetCommittee()[itree]->GetMvaValue() );
00436
00437 if (fUseWeightedMembers){
00438 myMVA += GetBoostWeights()[itree] * tmpMVA;
00439 norm += GetBoostWeights()[itree];
00440 }
00441 else {
00442 myMVA += tmpMVA;
00443 norm += 1;
00444 }
00445 }
00446 return (norm != 0) ? myMVA /= Double_t(norm) : -999;
00447 }
00448
00449
// Write the training monitors created in ProcessOptions() (boost-factor
// histogram, error-fraction histogram, monitoring ntuple), then make BaseDir()
// the current directory.
// NOTE(review): Write() targets whatever directory is current when this runs;
// the log message names BaseDir(), but the cd() into it happens only after the
// writes. Unless BaseDir() changes directory itself, the objects may end up
// elsewhere — confirm whether BaseDir()->cd() should come first.
00450 void TMVA::MethodCommittee::WriteMonitoringHistosToFile( void ) const
00451 {
00452
00453
00454 Log() << kINFO << "Write monitoring histograms to file: " << BaseDir()->GetPath() << Endl;
00455
00456 fBoostFactorHist->Write();
00457 fErrFractHist->Write();
00458 fMonitorNtuple->Write();
00459
00460 BaseDir()->cd();
00461 }
00462
00463
00464
00465 vector< Double_t > TMVA::MethodCommittee::GetVariableImportance()
00466 {
00467
00468
00469
00470
00471
00472 fVariableImportance.resize(GetNvar());
00473
00474
00475
00476
00477
00478
00479
00480
00481
00482
00483 return fVariableImportance;
00484 }
00485
00486
00487 Double_t TMVA::MethodCommittee::GetVariableImportance(UInt_t ivar)
00488 {
00489
00490 vector<Double_t> relativeImportance = this->GetVariableImportance();
00491 if (ivar < (UInt_t)relativeImportance.size()) return relativeImportance[ivar];
00492 else Log() << kFATAL << "<GetVariableImportance> ivar = " << ivar << " is out of range " << Endl;
00493
00494 return -1;
00495 }
00496
00497
00498 const TMVA::Ranking* TMVA::MethodCommittee::CreateRanking()
00499 {
00500
00501
00502
00503 fRanking = new Ranking( GetName(), "Variable Importance" );
00504 vector< Double_t> importance(this->GetVariableImportance());
00505
00506 for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
00507 fRanking->AddRank( Rank( GetInputLabel(ivar), importance[ivar] ) );
00508 }
00509
00510 return fRanking;
00511 }
00512
00513
00514 void TMVA::MethodCommittee::MakeClassSpecific( std::ostream& fout, const TString& className ) const
00515 {
00516
00517 fout << " // not implemented for class: \"" << className << "\"" << endl;
00518 fout << "};" << endl;
00519 }
00520
00521
00522 void TMVA::MethodCommittee::GetHelpMessage() const
00523 {
00524
00525
00526
00527
00528 Log() << Endl;
00529 Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
00530 Log() << Endl;
00531 Log() << "<None>" << Endl;
00532 Log() << Endl;
00533 Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
00534 Log() << Endl;
00535 Log() << "<None>" << Endl;
00536 Log() << Endl;
00537 Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
00538 Log() << Endl;
00539 Log() << "<None>" << Endl;
00540 }