MethodCommittee.cxx

// @(#)root/tmva $Id: MethodCommittee.cxx 36966 2010-11-26 09:50:13Z evt $
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : MethodCommittee                                                       *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation                                                            *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *      LAPP, Annecy, France                                                      *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

//_______________________________________________________________________
//
// Boosting:
//
// The idea behind boosting is that signal events from the training
// sample that end up in a background node (and vice versa) are given a
// larger weight than events that are in the correct leaf node. This
// results in a re-weighted training event sample, with which a new
// decision tree can then be developed. The boosting can be applied several
// times (typically 100-500 times), and one ends up with a set of decision
// trees (a forest).
//
// Bagging:
//
// In this particular variant of the Boosted Decision Trees, the boosting
// is not done on the basis of previous training results, but by a simple
// stochastic re-sampling of the initial training event sample.
//_______________________________________________________________________
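//
// As a quick sketch of the reweighting implemented in AdaBoost() below
// (err and beta as defined there): each misclassified event has its weight
// scaled by
//
//    w -> w * ((1-err)/err)^beta
//
// after which all weights are renormalised so that their sum is unchanged.
//_______________________________________________________________________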

#include "TMVA/ClassifierFactory.h"
#include "TMVA/MethodCommittee.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"
#include "Riostream.h"
#include "TMath.h"
#include "TRandom3.h"
#include <algorithm>
#include "TObjString.h"
#include "TDirectory.h"
#include "TMVA/Ranking.h"
#include "TMVA/IMethod.h"

using std::vector;

REGISTER_METHOD(Committee)

ClassImp(TMVA::MethodCommittee)

//_______________________________________________________________________
TMVA::MethodCommittee::MethodCommittee( const TString& jobName,
                                        const TString& methodTitle,
                                        DataSetInfo& dsi,
                                        const TString& theOption,
                                        TDirectory* theTargetDir ) :
   TMVA::MethodBase( jobName, Types::kCommittee, methodTitle, dsi, theOption, theTargetDir ),
   fNMembers(100),
   fBoostType("AdaBoost"),
   fMemberType(Types::kMaxMethod),
   fUseMemberDecision(kFALSE),
   fUseWeightedMembers(kFALSE),
   fITree(0),
   fBoostFactor(0),
   fErrorFraction(0),
   fNnodes(0)
{
   // constructor
}

//_______________________________________________________________________
TMVA::MethodCommittee::MethodCommittee( DataSetInfo& theData,
                                        const TString& theWeightFile,
                                        TDirectory* theTargetDir ) :
   TMVA::MethodBase( Types::kCommittee, theData, theWeightFile, theTargetDir ),
   fNMembers(100),
   fBoostType("AdaBoost"),
   fMemberType(Types::kMaxMethod),
   fUseMemberDecision(kFALSE),
   fUseWeightedMembers(kFALSE),
   fITree(0),
   fBoostFactor(0),
   fErrorFraction(0),
   fNnodes(0)
{
   // constructor for calculating the Committee-MVA from previously generated
   // decision trees: the result of the previous training (the decision trees)
   // is read in via the weight file. Make sure the variables correspond to
   // the ones used in creating the weight file.
}

//_______________________________________________________________________
Bool_t TMVA::MethodCommittee::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets )
{
   // Committee can handle classification with 2 classes and regression with one regression-target
   if( type == Types::kClassification && numberClasses == 2 ) return kTRUE;
   if( type == Types::kRegression && numberTargets == 1 ) return kTRUE;
   return kFALSE;
}

//_______________________________________________________________________
void TMVA::MethodCommittee::DeclareOptions()
{
   // define the options (their key words) that can be set in the option string
   // known options:
   // NMembers           <int>        number of members in the committee
   // UseMemberDecision  <bool>       use the member's binary decision (otherwise use its MVA value)
   // UseWeightedMembers <bool>       use weighted members or a simple average in the committee response
   //
   // BoostType          <string>     boosting type
   //    available values are:        AdaBoost  <default>
   //                                 Bagging

   DeclareOptionRef(fNMembers, "NMembers", "number of members in the committee");
   DeclareOptionRef(fUseMemberDecision=kFALSE, "UseMemberDecision", "use binary information from IsSignal");
   DeclareOptionRef(fUseWeightedMembers=kTRUE, "UseWeightedMembers", "use weighted trees or simple average in classification from the forest");

   DeclareOptionRef(fBoostType, "BoostType", "boosting type");
   AddPreDefVal(TString("AdaBoost"));
   AddPreDefVal(TString("Bagging"));
}
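
//_______________________________________________________________________
// Illustrative usage sketch (not part of this file): the options declared
// above are passed as a TMVA option string when booking the method via the
// Factory; the option values below are made-up examples.
//
//    factory->BookMethod( TMVA::Types::kCommittee, "Committee",
//                         "NMembers=100:BoostType=AdaBoost:UseWeightedMembers=True" );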

//_______________________________________________________________________
void TMVA::MethodCommittee::ProcessOptions()
{
   // process user options

   // book monitoring histograms (currently for AdaBoost only)
   fBoostFactorHist = new TH1F("fBoostFactor","AdaBoost weights",100,1,100);
   fErrFractHist    = new TH2F("fErrFractHist","error fraction vs tree number",
                               fNMembers,0,fNMembers,50,0,0.5);
   fMonitorNtuple   = new TTree("fMonitorNtuple","Committee variables");
   fMonitorNtuple->Branch("iTree",&fITree,"iTree/I");
   fMonitorNtuple->Branch("boostFactor",&fBoostFactor,"boostFactor/D");
   fMonitorNtuple->Branch("errorFraction",&fErrorFraction,"errorFraction/D");
}

//_______________________________________________________________________
void TMVA::MethodCommittee::Init( void )
{
   // common initialisation with defaults for the Committee-Method

   fNMembers  = 100;
   fBoostType = "AdaBoost";

   fCommittee.clear();
   fBoostWeights.clear();
}

//_______________________________________________________________________
TMVA::MethodCommittee::~MethodCommittee( void )
{
   // destructor
   for (UInt_t i=0; i<GetCommittee().size(); i++) delete fCommittee[i];
   fCommittee.clear();
}

//_______________________________________________________________________
void TMVA::MethodCommittee::WriteStateToFile() const
{
   // function to write options and weights to file

   // get the filename
   TString fname(GetWeightFileName());
   Log() << kINFO << "creating weight file: " << fname << Endl;

   std::ofstream fout( fname );
   if (!fout.good()) { // file could not be opened --> Error
      Log() << kFATAL << "<WriteStateToFile> "
            << "unable to open output weight file: " << fname << Endl;
   }

   WriteStateToStream( fout );
}

//_______________________________________________________________________
void TMVA::MethodCommittee::Train( void )
{
   // training

   Log() << kINFO << "will train " << fNMembers << " committee members ... patience please" << Endl;

   Timer timer( fNMembers, GetName() );
   for (UInt_t imember=0; imember<fNMembers; imember++){
      timer.DrawProgressBar( imember );

      IMethod* method = ClassifierFactory::Instance().Create(std::string(Types::Instance().GetMethodName( fMemberType )),
                                                             GetJobName(),
                                                             GetMethodName(),
                                                             DataInfo(),
                                                             fMemberOption );

      // train each of the member methods
      method->Train();

      GetBoostWeights().push_back( this->Boost( dynamic_cast<MethodBase*>(method), imember ) );

      GetCommittee().push_back( method );

      fMonitorNtuple->Fill();
   }

   // get elapsed time
   Log() << kINFO << "elapsed time: " << timer.GetElapsedTime()
         << "                              " << Endl;
}

//_______________________________________________________________________
Double_t TMVA::MethodCommittee::Boost( TMVA::MethodBase* method, UInt_t imember )
{
   // apply the boosting algorithm (the algorithm is selected via the "BoostType"
   // option given in the constructor); the return value is the boost weight
   if (!method)
      return 0;

   if      (fBoostType=="AdaBoost") return this->AdaBoost( method );
   else if (fBoostType=="Bagging")  return this->Bagging( imember );
   else {
      Log() << kINFO << GetOptions() << Endl;
      Log() << kFATAL << "<Boost> unknown boost option called" << Endl;
   }
   return 1.0;
}

//_______________________________________________________________________
Double_t TMVA::MethodCommittee::AdaBoost( TMVA::MethodBase* method )
{
   // the AdaBoost implementation.
   // A new training sample is generated by reweighting
   // events that are misclassified by the decision tree. The weight
   // applied is w = (1-err)/err, or more generally:
   //            w = ((1-err)/err)^beta
   // where err is the fraction of misclassified events in the tree (err < 0.5,
   // assuming that the previous selection was better than random guessing),
   // and beta is a free parameter (default: beta = 1) that modifies the
   // boosting.
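   //
   // worked example (illustrative numbers, not from the source): with
   // err = 0.2 and beta = 1, boostFactor = (1-0.2)/0.2 = 4, i.e. each
   // misclassified event gets four times its previous weight, and the
   // weight returned for this member is log(4) ~ 1.39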

   Double_t adaBoostBeta = 1.;   // that's apparently the standard value :)

   // should never be called without existing training events
   if (Data()->GetNTrainingEvents() == 0) Log() << kFATAL << "<AdaBoost> number of training events is zero" << Endl;

   Double_t err=0, sumw=0, sumwfalse=0;
   vector<Char_t> correctSelected;

   // loop over all events in training tree
   MethodBase* mbase = (MethodBase*)method;
   for (Int_t ievt=0; ievt<Data()->GetNTrainingEvents(); ievt++) {

      Event* ev = Data()->GetEvent(ievt);

      // total sum of event weights
      sumw += ev->GetBoostWeight();

      // decide whether it is signal or background-like
      Bool_t isSignalType = mbase->IsSignalLike();

      // record whether the event was classified correctly, and sum up
      // the weight of the misclassified ones
      if (isSignalType == DataInfo().IsSignal(ev))
         correctSelected.push_back( kTRUE );
      else {
         sumwfalse += ev->GetBoostWeight();
         correctSelected.push_back( kFALSE );
      }
   }

   if (0 == sumw) {
      Log() << kFATAL << "<AdaBoost> fatal error: sum of event boost weights is zero" << Endl;
   }

   // compute the boost factor
   err = sumwfalse/sumw;

   Double_t newSumw = 0;
   Double_t boostFactor = 1;
   if (err > 0){
      if (adaBoostBeta == 1){
         boostFactor = (1-err)/err;
      }
      else {
         boostFactor = TMath::Power((1-err)/err, adaBoostBeta);
      }
   }
   else {
      boostFactor = 1000; // default
   }

   // now fill new boost weights
   for (Int_t ievt=0; ievt<Data()->GetNTrainingEvents(); ievt++) {

      Event *ev = Data()->GetEvent(ievt);

      // reweight the misclassified events
      if (!correctSelected[ievt]) ev->SetBoostWeight( ev->GetBoostWeight() * boostFactor );

      newSumw += ev->GetBoostWeight();
   }

   // re-normalise the boost weights
   for (Int_t ievt=0; ievt<Data()->GetNTrainingEvents(); ievt++) {
      Event *ev = Data()->GetEvent(ievt);
      ev->SetBoostWeight( ev->GetBoostWeight() * sumw / newSumw );
   }

   fBoostFactorHist->Fill(boostFactor);
   fErrFractHist->Fill(GetCommittee().size(),err);

   // save for ntuple
   fBoostFactor   = boostFactor;
   fErrorFraction = err;

   // return weight factor for this committee member
   return TMath::Log(boostFactor);
}

//_______________________________________________________________________
Double_t TMVA::MethodCommittee::Bagging( UInt_t imember )
{
   // call it bootstrapping, re-sampling or whatever you like; in the end it is
   // nothing but applying a random boost weight to each event.
   Double_t newSumw = 0;
   TRandom3* trandom = new TRandom3( imember );

   // loop over all events in training tree
   for (Int_t ievt=0; ievt<Data()->GetNTrainingEvents(); ievt++) {
      Event* ev = Data()->GetEvent(ievt);

      // assign a random boost weight to the training event
      Double_t newWeight = trandom->Rndm();
      ev->SetBoostWeight( newWeight );
      newSumw += newWeight;
   }

   // re-normalise the boost weights
   for (Int_t ievt=0; ievt<Data()->GetNTrainingEvents(); ievt++) {
      Event* ev = Data()->GetEvent(ievt);
      ev->SetBoostWeight( ev->GetBoostWeight() * Data()->GetNTrainingEvents() / newSumw );
   }

   delete trandom;

   // return weight factor for this committee member: since the weights are
   // random per event, all members count equally, so just return a constant 1
   return 1.0;
}
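
// Note (illustrative): since TRandom3::Rndm() draws uniformly from (0,1],
// the raw weights sum to about N/2 for N training events, so the
// renormalisation above scales each weight up by roughly a factor of two
// on average.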

//_______________________________________________________________________
void TMVA::MethodCommittee::AddWeightsXMLTo( void* /*parent*/ ) const {
   Log() << kFATAL << "Please implement writing of weights as XML" << Endl;
}

//_______________________________________________________________________
void  TMVA::MethodCommittee::ReadWeightsFromStream( istream& istr )
{
   // read the state of the method from an input stream

   // explicitly destroy objects in vector
   std::vector<IMethod*>::iterator member = GetCommittee().begin();
   for (; member != GetCommittee().end(); member++) delete *member;

   GetCommittee().clear();
   GetBoostWeights().clear();

   TString  dummy;
   UInt_t   imember;
   Double_t boostWeight;

   DataSetInfo& dsi = DataInfo(); // this needs to be changed for the different kinds of committee methods

   // loop over all members in committee
   for (UInt_t i=0; i<fNMembers; i++) {

      istr >> dummy >> dummy >> dummy >> imember;
      istr >> dummy >> dummy >> boostWeight;

      if (imember != i) {
         Log() << kFATAL << "<ReadWeightsFromStream> fatal error while reading weight file: "
               << "mismatch imember: " << imember << " != i: " << i << Endl;
      }

      // initialize methods
      IMethod* method = ClassifierFactory::Instance().Create(std::string(Types::Instance().GetMethodName( fMemberType )), dsi, "" );

      // read weight file
      MethodBase* m = dynamic_cast<MethodBase*>(method);
      if (m)
         m->ReadStateFromStream(istr);
      GetCommittee().push_back(method);
      GetBoostWeights().push_back(boostWeight);
   }
}
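
// Note (illustrative): the parsing above expects, for each member, one line
// whose fourth token is the member index and one line whose third token is
// the boost weight, followed by the member's own state. The exact labels are
// whatever WriteStateToStream() emits; a hypothetical record might look like
// "Member index : 0" followed by "BoostWeight : 1.386".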

//_______________________________________________________________________
Double_t TMVA::MethodCommittee::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // return the MVA value (range [-1;1]) that classifies the
   // event according to the majority vote of the committee members.
   // In the literature, people actually use a weighted majority vote
   // (using the boost weights); however, I did not see any improvement
   // in doing so :(
   // --> the weighting can be switched on/off via the UseWeightedMembers option
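   //
   // in formula form (alpha_i = boost weight of member i, y_i = member
   // response, N = number of members; notation is illustrative only):
   //
   //    MVA = ( Sum_i alpha_i * y_i ) / ( Sum_i alpha_i )   if UseWeightedMembers
   //    MVA = ( Sum_i y_i ) / N                             otherwise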

   // cannot determine error
   NoErrorCalc(err, errUpper);

   Double_t myMVA = 0;
   Double_t norm  = 0;
   for (UInt_t itree=0; itree<GetCommittee().size(); itree++) {

      MethodBase* m = dynamic_cast<MethodBase*>(GetCommittee()[itree]);
      if (m==0) continue;

      Double_t tmpMVA = ( fUseMemberDecision ? ( m->IsSignalLike() ? 1.0 : -1.0 )
                          : GetCommittee()[itree]->GetMvaValue() );

      if (fUseWeightedMembers){
         myMVA += GetBoostWeights()[itree] * tmpMVA;
         norm  += GetBoostWeights()[itree];
      }
      else {
         myMVA += tmpMVA;
         norm  += 1;
      }
   }
   return (norm != 0) ? myMVA / norm : -999;
}

//_______________________________________________________________________
void  TMVA::MethodCommittee::WriteMonitoringHistosToFile( void ) const
{
   // here we could write some histograms created during the processing
   // to the output file
   Log() << kINFO << "Write monitoring histograms to file: " << BaseDir()->GetPath() << Endl;

   BaseDir()->cd();

   fBoostFactorHist->Write();
   fErrFractHist->Write();
   fMonitorNtuple->Write();
}

// return the individual relative variable importance
//_______________________________________________________________________
vector< Double_t > TMVA::MethodCommittee::GetVariableImportance()
{
   // return the relative variable importance, normalised such that all
   // variables together have importance 1. The importance is
   // evaluated as the total separation-gain that this variable had in
   // the decision trees (weighted by the number of events).
   // NOTE: the actual computation below is commented out, so the
   // returned vector is currently zero-initialised.

   fVariableImportance.resize(GetNvar());
   //    Double_t  sum=0;
   //    for (int itree = 0; itree < fNMembers; itree++){
   //       vector<Double_t> relativeImportance(GetCommittee()[itree]->GetVariableImportance());
   //       for (unsigned int i=0; i< relativeImportance.size(); i++) {
   //          fVariableImportance[i] += relativeImportance[i] ;
   //       }
   //    }
   //    for (unsigned int i=0; i< fVariableImportance.size(); i++) sum += fVariableImportance[i];
   //    for (unsigned int i=0; i< fVariableImportance.size(); i++) fVariableImportance[i] /= sum;

   return fVariableImportance;
}

//_______________________________________________________________________
Double_t TMVA::MethodCommittee::GetVariableImportance(UInt_t ivar)
{
   // return the importance of a single variable
   vector<Double_t> relativeImportance = this->GetVariableImportance();
   if (ivar < (UInt_t)relativeImportance.size()) return relativeImportance[ivar];
   else Log() << kFATAL << "<GetVariableImportance> ivar = " << ivar << " is out of range" << Endl;

   return -1;
}

//_______________________________________________________________________
const TMVA::Ranking* TMVA::MethodCommittee::CreateRanking()
{
   // computes ranking of input variables

   // create the ranking object
   fRanking = new Ranking( GetName(), "Variable Importance" );
   vector< Double_t> importance(this->GetVariableImportance());

   for (UInt_t ivar=0; ivar<GetNvar(); ivar++) {
      fRanking->AddRank( Rank( GetInputLabel(ivar), importance[ivar] ) );
   }

   return fRanking;
}

//_______________________________________________________________________
void TMVA::MethodCommittee::MakeClassSpecific( std::ostream& fout, const TString& className ) const
{
   // write specific classifier response
   fout << "   // not implemented for class: \"" << className << "\"" << endl;
   fout << "};" << endl;
}

//_______________________________________________________________________
void TMVA::MethodCommittee::GetHelpMessage() const
{
   // get help message text
   //
   // typical length of text line:
   //         "|--------------------------------------------------------------|"
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Short description:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "<None>" << Endl;
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance optimisation:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "<None>" << Endl;
   Log() << Endl;
   Log() << gTools().Color("bold") << "--- Performance tuning via configuration options:" << gTools().Color("reset") << Endl;
   Log() << Endl;
   Log() << "<None>" << Endl;
}