TMVAMulticlass.C

Go to the documentation of this file.
00001 /**********************************************************************************
00002  * Project   : TMVA - a Root-integrated toolkit for multivariate data analysis    *
00003  * Package   : TMVA                                                               *
00004  * Root Macro: TMVAMulticlass                                                     *
00005  *                                                                                *
00006  * This macro provides a simple example for the training and testing of the TMVA  *
00007  * multiclass classification                                                      *
00008  **********************************************************************************/
00009 
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include "TFile.h"
#include "TTree.h"
#include "TString.h"
#include "TSystem.h"
#include "TROOT.h"

#include "TMVAMultiClassGui.C"

#ifndef __CINT__
#include "TMVA/Tools.h"
#include "TMVA/Factory.h"
#endif
00027 
00028 using namespace TMVA;
00029 
00030 void TMVAMulticlass( TString myMethodList = "" )
00031 {
00032    
00033    TMVA::Tools::Instance();
00034    
00035    //---------------------------------------------------------------
00036    // default MVA methods to be trained + tested
00037    std::map<std::string,int> Use;
00038    Use["MLP"]             = 1;
00039    Use["BDTG"]            = 1;
00040    Use["FDA_GA"]          = 0;
00041    //---------------------------------------------------------------
00042    
00043    std::cout << std::endl;
00044    std::cout << "==> Start TMVAMulticlass" << std::endl;
00045    
00046    if (myMethodList != "") {
00047       for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) it->second = 0;
00048       
00049       std::vector<TString> mlist = TMVA::gTools().SplitString( myMethodList, ',' );
00050       for (UInt_t i=0; i<mlist.size(); i++) {
00051          std::string regMethod(mlist[i]);
00052 
00053          if (Use.find(regMethod) == Use.end()) {
00054             std::cout << "Method \"" << regMethod << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
00055             for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) std::cout << it->first << " ";
00056             std::cout << std::endl;
00057             return;
00058          }
00059          Use[regMethod] = 1;
00060       }
00061    }
00062 
00063    // Create a new root output file.
00064    TString outfileName = "TMVAMulticlass.root";
00065    TFile* outputFile = TFile::Open( outfileName, "RECREATE" );
00066    
00067    TMVA::Factory *factory = new TMVA::Factory( "TMVAMulticlass", outputFile,
00068                                                "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=multiclass" );
00069    factory->AddVariable( "var1", 'F' );
00070    factory->AddVariable( "var2", "Variable 2", "", 'F' );
00071    factory->AddVariable( "var3", "Variable 3", "units", 'F' );
00072    factory->AddVariable( "var4", "Variable 4", "units", 'F' );
00073 
00074    TFile *input(0);
00075    TString fname = "./tmva_example_multiple_background.root";
00076    if (!gSystem->AccessPathName( fname )) {
00077       // first we try to find the file in the local directory
00078       std::cout << "--- TMVAMulticlass   : Accessing " << fname << std::endl;
00079       input = TFile::Open( fname );
00080    }
00081    else {
00082       cout << "Creating testdata...." << std::endl;
00083       gROOT->ProcessLine(".L createData.C+");
00084       gROOT->ProcessLine("create_MultipleBackground(2000)");
00085       cout << " created tmva_example_multiple_background.root for tests of the multiclass features"<<endl;
00086       input = TFile::Open( fname );
00087    }
00088    if (!input) {
00089       std::cout << "ERROR: could not open data file" << std::endl;
00090       exit(1);
00091    }
00092 
00093    TTree *signal      = (TTree*)input->Get("TreeS");
00094    TTree *background0 = (TTree*)input->Get("TreeB0");
00095    TTree *background1 = (TTree*)input->Get("TreeB1");
00096    TTree *background2 = (TTree*)input->Get("TreeB2");
00097    
00098    gROOT->cd( outfileName+TString(":/") );
00099    factory->AddTree    (signal,"Signal");
00100    factory->AddTree    (background0,"bg0");
00101    factory->AddTree    (background1,"bg1");
00102    factory->AddTree    (background2,"bg2");
00103    
00104    factory->PrepareTrainingAndTestTree( "", "SplitMode=Random:NormMode=NumEvents:!V" );
00105 
00106    if (Use["BDTG"]) // gradient boosted decision trees
00107       factory->BookMethod( TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.10:UseBaggedGrad:GradBaggingFraction=0.50:nCuts=20:NNodesMax=8");
00108    if (Use["MLP"]) // neural network
00109       factory->BookMethod( TMVA::Types::kMLP, "MLP", "!H:!V:NeuronType=tanh:NCycles=1000:HiddenLayers=N+5,5:TestRate=5:EstimatorType=MSE");
00110    if (Use["FDA_GA"]) // functional discriminant with GA minimizer
00111       factory->BookMethod( TMVA::Types::kFDA, "FDA_GA", "H:!V:Formula=(0)+(1)*x0+(2)*x1+(3)*x2+(4)*x3:ParRanges=(-1,1);(-10,10);(-10,10);(-10,10);(-10,10):FitMethod=GA:PopSize=300:Cycles=3:Steps=20:Trim=True:SaveBestGen=1" );
00112    
00113   // Train MVAs using the set of training events
00114    factory->TrainAllMethods();
00115 
00116    // ---- Evaluate all MVAs using the set of test events
00117    factory->TestAllMethods();
00118 
00119    // ----- Evaluate and compare performance of all configured MVAs
00120    factory->EvaluateAllMethods();
00121 
00122    // --------------------------------------------------------------
00123    
00124    // Save the output
00125    outputFile->Close();
00126    
00127    std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
00128    std::cout << "==> TMVAClassification is done!" << std::endl;
00129    
00130    delete factory;
00131    
00132    // Launch the GUI for the root macros
00133    if (!gROOT->IsBatch()) TMVAMultiClassGui( outfileName );
00134    
00135    
00136 }

Generated on Tue Jul 5 15:26:37 2011 for ROOT_528-00b_version by  doxygen 1.5.1