00001
00002
00003
00004
00005
00006
00007
00008
00009
00010 #include <cstdlib>
00011 #include <iostream>
00012 #include <map>
00013 #include <string>
00014
00015 #include "TFile.h"
00016 #include "TTree.h"
00017 #include "TString.h"
00018 #include "TSystem.h"
00019 #include "TROOT.h"
00020
00021 #include "TMVAMultiClassGui.C"
00022
00023 #ifndef __CINT__
00024 #include "TMVA/Tools.h"
00025 #include "TMVA/Factory.h"
00026 #endif
00027
00028 using namespace TMVA;
00029
00030 void TMVAMulticlass( TString myMethodList = "" )
00031 {
00032
00033 TMVA::Tools::Instance();
00034
00035
00036
00037 std::map<std::string,int> Use;
00038 Use["MLP"] = 1;
00039 Use["BDTG"] = 1;
00040 Use["FDA_GA"] = 0;
00041
00042
00043 std::cout << std::endl;
00044 std::cout << "==> Start TMVAMulticlass" << std::endl;
00045
00046 if (myMethodList != "") {
00047 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) it->second = 0;
00048
00049 std::vector<TString> mlist = TMVA::gTools().SplitString( myMethodList, ',' );
00050 for (UInt_t i=0; i<mlist.size(); i++) {
00051 std::string regMethod(mlist[i]);
00052
00053 if (Use.find(regMethod) == Use.end()) {
00054 std::cout << "Method \"" << regMethod << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
00055 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) std::cout << it->first << " ";
00056 std::cout << std::endl;
00057 return;
00058 }
00059 Use[regMethod] = 1;
00060 }
00061 }
00062
00063
00064 TString outfileName = "TMVAMulticlass.root";
00065 TFile* outputFile = TFile::Open( outfileName, "RECREATE" );
00066
00067 TMVA::Factory *factory = new TMVA::Factory( "TMVAMulticlass", outputFile,
00068 "!V:!Silent:Color:DrawProgressBar:Transformations=I;D;P;G,D:AnalysisType=multiclass" );
00069 factory->AddVariable( "var1", 'F' );
00070 factory->AddVariable( "var2", "Variable 2", "", 'F' );
00071 factory->AddVariable( "var3", "Variable 3", "units", 'F' );
00072 factory->AddVariable( "var4", "Variable 4", "units", 'F' );
00073
00074 TFile *input(0);
00075 TString fname = "./tmva_example_multiple_background.root";
00076 if (!gSystem->AccessPathName( fname )) {
00077
00078 std::cout << "--- TMVAMulticlass : Accessing " << fname << std::endl;
00079 input = TFile::Open( fname );
00080 }
00081 else {
00082 cout << "Creating testdata...." << std::endl;
00083 gROOT->ProcessLine(".L createData.C+");
00084 gROOT->ProcessLine("create_MultipleBackground(2000)");
00085 cout << " created tmva_example_multiple_background.root for tests of the multiclass features"<<endl;
00086 input = TFile::Open( fname );
00087 }
00088 if (!input) {
00089 std::cout << "ERROR: could not open data file" << std::endl;
00090 exit(1);
00091 }
00092
00093 TTree *signal = (TTree*)input->Get("TreeS");
00094 TTree *background0 = (TTree*)input->Get("TreeB0");
00095 TTree *background1 = (TTree*)input->Get("TreeB1");
00096 TTree *background2 = (TTree*)input->Get("TreeB2");
00097
00098 gROOT->cd( outfileName+TString(":/") );
00099 factory->AddTree (signal,"Signal");
00100 factory->AddTree (background0,"bg0");
00101 factory->AddTree (background1,"bg1");
00102 factory->AddTree (background2,"bg2");
00103
00104 factory->PrepareTrainingAndTestTree( "", "SplitMode=Random:NormMode=NumEvents:!V" );
00105
00106 if (Use["BDTG"])
00107 factory->BookMethod( TMVA::Types::kBDT, "BDTG", "!H:!V:NTrees=1000:BoostType=Grad:Shrinkage=0.10:UseBaggedGrad:GradBaggingFraction=0.50:nCuts=20:NNodesMax=8");
00108 if (Use["MLP"])
00109 factory->BookMethod( TMVA::Types::kMLP, "MLP", "!H:!V:NeuronType=tanh:NCycles=1000:HiddenLayers=N+5,5:TestRate=5:EstimatorType=MSE");
00110 if (Use["FDA_GA"])
00111 factory->BookMethod( TMVA::Types::kFDA, "FDA_GA", "H:!V:Formula=(0)+(1)*x0+(2)*x1+(3)*x2+(4)*x3:ParRanges=(-1,1);(-10,10);(-10,10);(-10,10);(-10,10):FitMethod=GA:PopSize=300:Cycles=3:Steps=20:Trim=True:SaveBestGen=1" );
00112
00113
00114 factory->TrainAllMethods();
00115
00116
00117 factory->TestAllMethods();
00118
00119
00120 factory->EvaluateAllMethods();
00121
00122
00123
00124
00125 outputFile->Close();
00126
00127 std::cout << "==> Wrote root file: " << outputFile->GetName() << std::endl;
00128 std::cout << "==> TMVAClassification is done!" << std::endl;
00129
00130 delete factory;
00131
00132
00133 if (!gROOT->IsBatch()) TMVAMultiClassGui( outfileName );
00134
00135
00136 }