TMVAMulticlassApplication.cxx

Go to the documentation of this file.
00001 /**********************************************************************************
00002  * Project   : TMVA - a Root-integrated toolkit for multivariate data analysis    *
00003  * Package   : TMVA                                                               *
00004  * Executable: TMVAMulticlassApplication                                          *
00005  *                                                                                *
00006  * This executable provides a simple example of how to use the trained multiclass *
00007  * classifiers within an analysis module                                          *
00008  **********************************************************************************/
00009 
00010 #include <cstdlib>
00011 #include <iostream>
00012 #include <map>
00013 #include <string>
00014 
00015 #include "TFile.h"
00016 #include "TTree.h"
00017 #include "TString.h"
00018 #include "TSystem.h"
00019 #include "TROOT.h"
00020 #include "TStopwatch.h"
00021 #include "TH1F.h"
00022 
00023 #include "TMVA/Tools.h"
00024 #include "TMVA/Reader.h"
00025 #include "TMVA/Config.h"
00026 
00027 using namespace TMVA;
00028 
00029 int main(int argc, char** argv )
00030 {
00031    TMVA::Tools::Instance();
00032    
00033    //---------------------------------------------------------------
00034    // default MVA methods to be trained + tested
00035    std::map<std::string,int> Use;
00036    Use["MLP"]             = 1;
00037    Use["BDTG"]            = 1;
00038    Use["FDA_GA"]          = 0;
00039    //---------------------------------------------------------------
00040   
00041    std::cout << std::endl;
00042    std::cout << "==> Start TMVAMulticlassApplication" << std::endl; 
00043 
00044    if (argc>1) {
00045       for (std::map<std::string,int>::iterator it = Use.begin();
00046            it != Use.end(); it++) {
00047          it->second = 0;
00048       }
00049    }
00050    for (int i=1; i<argc; i++) {
00051       std::string regMethod(argv[i]);
00052       if (Use.find(regMethod) == Use.end()) {
00053          std::cout << "Method " << regMethod << " not known in TMVA under this name. Please try one of:" << std::endl;
00054          for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) std::cout << it->first << " ";
00055          std::cout << std::endl;
00056          return 1;
00057       }
00058       Use[regMethod] = kTRUE;
00059    }
00060 
00061    
00062    // create the Reader object
00063    TMVA::Reader *reader = new TMVA::Reader( "!Color:!Silent" );    
00064 
00065    // create a set of variables and declare them to the reader
00066    // - the variable names must corresponds in name and type to 
00067    // those given in the weight file(s) that you use
00068    Float_t var1, var2, var3, var4;
00069    reader->AddVariable( "var1", &var1 );
00070    reader->AddVariable( "var2", &var2 );
00071    reader->AddVariable( "var3", &var3 );
00072    reader->AddVariable( "var4", &var4 );
00073 
00074    // book the MVA methods
00075    TString dir    = "weights/";
00076    TString prefix = "TMVAMulticlass";
00077    
00078    for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) {
00079       if (it->second) {
00080          TString methodName = it->first + " method";
00081          TString weightfile = dir + prefix + "_" + TString(it->first) + ".weights.xml";
00082          reader->BookMVA( methodName, weightfile ); 
00083       }
00084    }
00085 
00086    // book output histograms
00087    UInt_t nbin = 100;
00088    TH1F *histMLP_signal(0), *histBDTG_signal(0), *histFDAGA_signal(0);
00089    if (Use["MLP"])    
00090       histMLP_signal    = new TH1F( "MVA_MLP_signal",    "MVA_MLP_signal",    nbin, 0., 1.1 );
00091    if (Use["BDTG"])
00092       histBDTG_signal  = new TH1F( "MVA_BDTG_signal",   "MVA_BDTG_signal",   nbin, 0., 1.1 );
00093    if (Use["FDA_GA"])
00094       histFDAGA_signal = new TH1F( "MVA_FDA_GA_signal", "MVA_FDA_GA_signal", nbin, 0., 1.1 );
00095 
00096 
00097    TFile *input(0);
00098    TString fname = "./tmva_example_multiple_background.root";
00099    if (!gSystem->AccessPathName( fname )) {
00100       input = TFile::Open( fname ); // check if file in local directory exists
00101    }
00102    if (!input) {
00103       std::cout << "ERROR: could not open data file, please generate example data first!" << std::endl;
00104       exit(1);
00105    }
00106    std::cout << "--- TMVAMulticlassApp : Using input file: " << input->GetName() << std::endl;
00107    
00108    // prepare the tree
00109    // - here the variable names have to corresponds to your tree
00110    // - you can use the same variables as above which is slightly faster,
00111    //   but of course you can use different ones and copy the values inside the event loop
00112   
00113    TTree* theTree = (TTree*)input->Get("TreeS");
00114    std::cout << "--- Select signal sample" << std::endl;
00115    theTree->SetBranchAddress( "var1", &var1 );
00116    theTree->SetBranchAddress( "var2", &var2 );
00117    theTree->SetBranchAddress( "var3", &var3 );
00118    theTree->SetBranchAddress( "var4", &var4 );
00119 
00120    std::cout << "--- Processing: " << theTree->GetEntries() << " events" << std::endl;
00121    TStopwatch sw;
00122    sw.Start();
00123 
00124    for (Long64_t ievt=0; ievt<theTree->GetEntries();ievt++) {
00125       if (ievt%1000 == 0){
00126          std::cout << "--- ... Processing event: " << ievt << std::endl;
00127       }
00128       
00129       theTree->GetEntry(ievt);
00130       if (Use["MLP"])
00131          histMLP_signal->Fill((reader->EvaluateMulticlass( "MLP method" ))[0]);
00132       if (Use["BDTG"])
00133          histBDTG_signal->Fill((reader->EvaluateMulticlass( "BDTG method" ))[0]);
00134       if (Use["FDA_GA"])
00135          histFDAGA_signal->Fill((reader->EvaluateMulticlass( "FDA_GA method" ))[0]);
00136     
00137       
00138    }
00139    
00140    // get elapsed time
00141    sw.Stop();
00142    std::cout << "--- End of event loop: "; sw.Print();
00143    
00144    TFile *target  = new TFile( "TMVAMulticlassApp.root","RECREATE" );
00145    if (Use["MLP"])
00146       histMLP_signal->Write();
00147    if (Use["BDTG"])
00148       histBDTG_signal->Write(); 
00149    if (Use["FDA_GA"])
00150       histFDAGA_signal->Write();
00151 
00152 
00153    target->Close();
00154    std::cout << "--- Created root file: \"TMVMulticlassApp.root\" containing the MVA output histograms" << std::endl;
00155 
00156    delete reader;
00157    
00158    std::cout << "==> TMVAClassificationApplication is done!" << std::endl << std::endl;
00159 }

Generated on Tue Jul 5 15:26:37 2011 for ROOT_528-00b_version by  doxygen 1.5.1