TMVAMulticlassApplication.C

Go to the documentation of this file.
00001 /**********************************************************************************
00002  * Project   : TMVA - a Root-integrated toolkit for multivariate data analysis    *
00003  * Package   : TMVA                                                               *
00004  * Root Macro: TMVAMulticlassApplication                                          *
00005  *                                                                                *
00006  * This macro provides a simple example of how to use the trained multiclass      *
00007  * classifiers within an analysis module                                          *
00008  **********************************************************************************/
00009 
00010 #include <cstdlib>
00011 #include <iostream>
00012 #include <map>
00013 #include <string>
00014 #include <vector>
00015 
00016 #include "TFile.h"
00017 #include "TTree.h"
00018 #include "TString.h"
00019 #include "TSystem.h"
00020 #include "TROOT.h"
00021 #include "TStopwatch.h"
00022 #include "TH1F.h"
00023 
00024 #if not defined(__CINT__) || defined(__MAKECINT__)
00025 #include "TMVA/Tools.h"
00026 #include "TMVA/Reader.h"
00027 #endif
00028 
00029 using namespace TMVA;
00030 
00031 void TMVAMulticlassApplication( TString myMethodList = "" )
00032 {
00033 #ifdef __CINT__
00034    gROOT->ProcessLine( ".O0" ); // turn off optimization in CINT
00035 #endif
00036 
00037    TMVA::Tools::Instance();
00038    
00039    //---------------------------------------------------------------
00040    // default MVA methods to be trained + tested
00041    std::map<std::string,int> Use;
00042    Use["MLP"]             = 1;
00043    Use["BDTG"]            = 1;
00044    Use["FDA_GA"]          = 0;
00045    //---------------------------------------------------------------
00046   
00047    std::cout << std::endl;
00048    std::cout << "==> Start TMVAMulticlassApplication" << std::endl; 
00049    if (myMethodList != "") {
00050       for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) it->second = 0;
00051 
00052       std::vector<TString> mlist = gTools().SplitString( myMethodList, ',' );
00053       for (UInt_t i=0; i<mlist.size(); i++) {
00054          std::string regMethod(mlist[i]);
00055 
00056          if (Use.find(regMethod) == Use.end()) {
00057             std::cout << "Method \"" << regMethod << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
00058             for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) std::cout << it->first << " " << std::endl;
00059             std::cout << std::endl;
00060             return;
00061          }
00062          Use[regMethod] = 1;
00063       }
00064    }
00065 
00066    
00067    // create the Reader object
00068    TMVA::Reader *reader = new TMVA::Reader( "!Color:!Silent" );    
00069 
00070    // create a set of variables and declare them to the reader
00071    // - the variable names must corresponds in name and type to 
00072    // those given in the weight file(s) that you use
00073    Float_t var1, var2, var3, var4;
00074    reader->AddVariable( "var1", &var1 );
00075    reader->AddVariable( "var2", &var2 );
00076    reader->AddVariable( "var3", &var3 );
00077    reader->AddVariable( "var4", &var4 );
00078 
00079    // book the MVA methods
00080    TString dir    = "weights/";
00081    TString prefix = "TMVAMulticlass";
00082    
00083    for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) {
00084       if (it->second) {
00085         TString methodName = TString(it->first) + TString(" method");
00086         TString weightfile = dir + prefix + TString("_") + TString(it->first) + TString(".weights.xml"); 
00087         reader->BookMVA( methodName, weightfile ); 
00088       }
00089    }
00090 
00091    // book output histograms
00092    UInt_t nbin = 100;
00093    TH1F *histMLP_signal(0), *histBDTG_signal(0), *histFDAGA_signal(0);
00094    if (Use["MLP"])    
00095       histMLP_signal    = new TH1F( "MVA_MLP_signal",    "MVA_MLP_signal",    nbin, 0., 1.1 );
00096    if (Use["BDTG"])
00097       histBDTG_signal  = new TH1F( "MVA_BDTG_signal",   "MVA_BDTG_signal",   nbin, 0., 1.1 );
00098    if (Use["FDA_GA"])
00099       histFDAGA_signal = new TH1F( "MVA_FDA_GA_signal", "MVA_FDA_GA_signal", nbin, 0., 1.1 );
00100 
00101    TFile *input(0); 
00102    TString fname = "./tmva_example_multiple_background.root";
00103    if (!gSystem->AccessPathName( fname )) {
00104       input = TFile::Open( fname ); // check if file in local directory exists
00105    }
00106    if (!input) {
00107       std::cout << "ERROR: could not open data file, please generate example data first!" << std::endl;
00108       exit(1);
00109    }
00110    std::cout << "--- TMVAMulticlassApp : Using input file: " << input->GetName() << std::endl;
00111    
00112    // prepare the tree
00113    // - here the variable names have to corresponds to your tree
00114    // - you can use the same variables as above which is slightly faster,
00115    //   but of course you can use different ones and copy the values inside the event loop
00116   
00117    TTree* theTree = (TTree*)input->Get("TreeS");
00118    std::cout << "--- Select signal sample" << std::endl;
00119    theTree->SetBranchAddress( "var1", &var1 );
00120    theTree->SetBranchAddress( "var2", &var2 );
00121    theTree->SetBranchAddress( "var3", &var3 );
00122    theTree->SetBranchAddress( "var4", &var4 );
00123 
00124    std::cout << "--- Processing: " << theTree->GetEntries() << " events" << std::endl;
00125    TStopwatch sw;
00126    sw.Start();
00127 
00128    for (Long64_t ievt=0; ievt<theTree->GetEntries();ievt++) {
00129       if (ievt%1000 == 0){
00130          std::cout << "--- ... Processing event: " << ievt << std::endl;
00131       }
00132       theTree->GetEntry(ievt);
00133       
00134       if (Use["MLP"])
00135          histMLP_signal->Fill((reader->EvaluateMulticlass( "MLP method" ))[0]);
00136       if (Use["BDTG"])
00137          histBDTG_signal->Fill((reader->EvaluateMulticlass( "BDTG method" ))[0]);
00138       if (Use["FDA_GA"])
00139          histFDAGA_signal->Fill((reader->EvaluateMulticlass( "FDA_GA method" ))[0]);
00140       
00141    }
00142    
00143    // get elapsed time
00144    sw.Stop();
00145    std::cout << "--- End of event loop: "; sw.Print();
00146    
00147    TFile *target  = new TFile( "TMVAMulticlassApp.root","RECREATE" );
00148    if (Use["MLP"])
00149       histMLP_signal->Write();
00150    if (Use["BDTG"])
00151       histBDTG_signal->Write(); 
00152    if (Use["FDA_GA"])
00153       histFDAGA_signal->Write();
00154 
00155    target->Close();
00156    std::cout << "--- Created root file: \"TMVMulticlassApp.root\" containing the MVA output histograms" << std::endl;
00157 
00158    delete reader;
00159    
00160    std::cout << "==> TMVAClassificationApplication is done!" << std::endl << std::endl;
00161 }

Generated on Tue Jul 5 15:26:37 2011 for ROOT_528-00b_version by  doxygen 1.5.1