00001
00002
00003
00004
00005
00006
00007
00008
00009
#include <cstdlib>
#include <iostream>
#include <map>
#include <string>
#include <vector>

#include "TFile.h"
#include "TTree.h"
#include "TH1F.h"
#include "TString.h"
#include "TSystem.h"
#include "TROOT.h"
#include "TStopwatch.h"

#include "TMVA/Reader.h"
#include "TMVA/Tools.h"
#include "TMVA/MethodCuts.h"
00026
00027
00028 Bool_t ReadDataFromAsciiIFormat = kFALSE;
00029
00030 int main( int argc, char** argv )
00031 {
00032
00033
00034 std::map<std::string,int> Use;
00035
00036
00037 Use["PDERS"] = 0;
00038 Use["PDEFoam"] = 1;
00039 Use["KNN"] = 1;
00040
00041
00042 Use["LD"] = 1;
00043
00044
00045 Use["FDA_GA"] = 1;
00046 Use["FDA_MC"] = 0;
00047 Use["FDA_MT"] = 0;
00048 Use["FDA_GAMT"] = 0;
00049
00050
00051 Use["MLP"] = 1;
00052
00053
00054 Use["SVM"] = 0;
00055
00056
00057 Use["BDT"] = 0;
00058 Use["BDTG"] = 1;
00059
00060
00061 std::cout << std::endl;
00062 std::cout << "==> Start TMVARegressionApplication" << std::endl;
00063
00064 std::cout << "Running the following methods" << std::endl;
00065 if (argc>1) {
00066 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) it->second = 0;
00067 }
00068 for (int i=1; i<argc; i++) {
00069 std::string regMethod(argv[i]);
00070 if (Use.find(regMethod) == Use.end()) {
00071 std::cout << "Method " << regMethod << " not known in TMVA under this name. Please try one of:" << std::endl;
00072 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) std::cout << it->first << " ";
00073 std::cout << std::endl;
00074 return 1;
00075 }
00076 Use[regMethod] = kTRUE;
00077 }
00078
00079
00080
00081
00082
00083
00084 TMVA::Reader *reader = new TMVA::Reader( "!Color:!Silent" );
00085
00086
00087
00088 Float_t var1, var2;
00089 reader->AddVariable( "var1", &var1 );
00090 reader->AddVariable( "var2", &var2 );
00091
00092
00093 Float_t spec1,spec2;
00094 reader->AddSpectator( "spec1:=var1*2", &spec1 );
00095 reader->AddSpectator( "spec2:=var1*3", &spec2 );
00096
00097
00098
00099 TString dir = "weights/";
00100 TString prefix = "TMVARegression";
00101
00102
00103 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) {
00104 if (it->second) {
00105 TString methodName = it->first + " method";
00106 TString weightfile = dir + prefix + "_" + TString(it->first) + ".weights.xml";
00107 reader->BookMVA( methodName, weightfile );
00108 }
00109 }
00110
00111
00112 TH1* hists[100];
00113 Int_t nhists = -1;
00114 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) {
00115 TH1* h = new TH1F( it->first.c_str(), TString(it->first) + " method", 100, -100, 600 );
00116 if (it->second) hists[++nhists] = h;
00117 }
00118 nhists++;
00119
00120
00121
00122
00123
00124 TFile *input(0);
00125 TString fname = "./tmva_reg_example.root";
00126 if (!gSystem->AccessPathName( fname )) {
00127 input = TFile::Open( fname );
00128 }
00129 else {
00130 input = TFile::Open( "http://root.cern.ch/files/tmva_reg_example.root" );
00131 }
00132
00133 if (!input) {
00134 std::cout << "ERROR: could not open data file" << std::endl;
00135 exit(1);
00136 }
00137 std::cout << "--- TMVARegressionApp : Using input file: " << input->GetName() << std::endl;
00138
00139
00140
00141
00142
00143
00144
00145
00146 TTree* theTree = (TTree*)input->Get("TreeR");
00147 std::cout << "--- Select signal sample" << std::endl;
00148 theTree->SetBranchAddress( "var1", &var1 );
00149 theTree->SetBranchAddress( "var2", &var2 );
00150
00151 std::cout << "--- Processing: " << theTree->GetEntries() << " events" << std::endl;
00152 TStopwatch sw;
00153 sw.Start();
00154 for (Long64_t ievt=0; ievt<theTree->GetEntries();ievt++) {
00155
00156 if (ievt%1000 == 0) {
00157 std::cout << "--- ... Processing event: " << ievt << std::endl;
00158 }
00159
00160 theTree->GetEntry(ievt);
00161
00162
00163
00164
00165 for (Int_t ih=0; ih<nhists; ih++) {
00166 TString title = hists[ih]->GetTitle();
00167 Float_t val = (reader->EvaluateRegression( title ))[0];
00168 hists[ih]->Fill( val );
00169 }
00170 }
00171 sw.Stop();
00172 std::cout << "--- End of event loop: "; sw.Print();
00173
00174
00175
00176 TFile *target = new TFile( "TMVARegApp.root","RECREATE" );
00177 for (Int_t ih=0; ih<nhists; ih++) hists[ih]->Write();
00178 target->Close();
00179
00180 std::cout << "--- Created root file: \"" << target->GetName()
00181 << "\" containing the MVA output histograms" << std::endl;
00182
00183 delete reader;
00184
00185 std::cout << "==> TMVARegressionApplication is done!" << std::endl << std::endl;
00186 }