00001
00002
00003
00004
00005
00006
00007
00008
00009 #include <cstdlib>
00010 #include <vector>
00011 #include <iostream>
00012 #include <map>
00013 #include <string>
00014
00015 #include "TFile.h"
00016 #include "TTree.h"
00017 #include "TString.h"
00018 #include "TSystem.h"
00019 #include "TROOT.h"
00020 #include "TStopwatch.h"
00021 #include "TGraph.h"
00022 #include "TH2F.h"
00023 #include "TH3F.h"
00024 #include "TMVAGui.C"
00025
00026 #if not defined(__CINT__) || defined(__MAKECINT__)
00027 #include "TMVA/Tools.h"
00028 #include "TMVA/Reader.h"
00029 #include "TMVA/MethodCuts.h"
00030 #endif
00031
00032 using namespace TMVA;
00033
00034
00035 void plot(TH2D *sig, TH2D *bkg, TH2F *MVA, TString v0="var0", TString v1="var1"){
00036
00037 TCanvas *c = new TCanvas(Form("DecisionBoundary%s",MVA->GetTitle()),MVA->GetTitle(),800,800);
00038
00039 gStyle->SetPalette(1);
00040 MVA->SetXTitle(v0);
00041 MVA->SetYTitle(v1);
00042 MVA->SetStats(0);
00043 MVA->Draw("cont1");
00044 sig->SetMarkerColor(2);
00045 bkg->SetMarkerColor(4);
00046 sig->SetMarkerStyle(20);
00047 bkg->SetMarkerStyle(20);
00048 sig->SetMarkerSize(.5);
00049 bkg->SetMarkerSize(.5);
00050 sig->Draw("same");
00051 bkg->Draw("same");
00052 }
00053
00054
00055 void PlotDecisionBoundary( TString myMethodList = "",TString v0="var0", TString v1="var1", TString dataFileName = "/home/hvoss/TMVA/TMVA_data/data/data_3Bumps.root", TString weightFilePrefix="TMVA")
00056 {
00057
00058
00059
00060
00061 TMVA::Tools::Instance();
00062
00063 std::map<std::string,int> Use;
00064
00065 Use["CutsGA"] = 0;
00066
00067 Use["Likelihood"] = 0;
00068 Use["LikelihoodD"] = 0;
00069 Use["LikelihoodPCA"] = 0;
00070 Use["LikelihoodKDE"] = 0;
00071 Use["LikelihoodMIX"] = 0;
00072
00073 Use["PDERS"] = 0;
00074 Use["PDERSD"] = 0;
00075 Use["PDERSPCA"] = 0;
00076 Use["PDERSkNN"] = 0;
00077 Use["PDEFoam"] = 0;
00078
00079 Use["KNN"] = 0;
00080
00081 Use["HMatrix"] = 0;
00082 Use["Fisher"] = 0;
00083 Use["FisherG"] = 0;
00084 Use["BoostedFisher"] = 0;
00085 Use["LD"] = 0;
00086
00087 Use["FDA_GA"] = 0;
00088 Use["FDA_SA"] = 0;
00089 Use["FDA_MC"] = 0;
00090 Use["FDA_MT"] = 0;
00091 Use["FDA_GAMT"] = 0;
00092 Use["FDA_MCMT"] = 0;
00093
00094 Use["MLP"] = 0;
00095 Use["MLPBFGS"] = 0;
00096 Use["CFMlpANN"] = 0;
00097 Use["TMlpANN"] = 0;
00098
00099 Use["SVM"] = 0;
00100
00101 Use["BDT"] = 0;
00102 Use["BDTD"] = 0;
00103 Use["BDTG"] = 0;
00104 Use["BDTB"] = 0;
00105
00106 Use["RuleFit"] = 0;
00107
00108 Use["Category"] = 0;
00109
00110 Use["Plugin"] = 0;
00111
00112
00113 std::cout << std::endl;
00114 std::cout << "==> Start TMVAClassificationApplication" << std::endl;
00115
00116 if (myMethodList != "") {
00117 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) it->second = 0;
00118
00119 std::vector<TString> mlist = gTools().SplitString( myMethodList, ',' );
00120 for (UInt_t i=0; i<mlist.size(); i++) {
00121 std::string regMethod(mlist[i]);
00122
00123 if (Use.find(regMethod) == Use.end()) {
00124 std::cout << "Method \"" << regMethod << "\" not known in TMVA under this name. Choose among the following:" << std::endl;
00125 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) std::cout << it->first << " ";
00126 std::cout << std::endl;
00127 return;
00128 }
00129 Use[regMethod] = 1;
00130 }
00131 }
00132
00133
00134
00135
00136 TMVA::Reader *reader = new TMVA::Reader( "!Color:!Silent" );
00137
00138
00139
00140
00141 Float_t var0, var1;
00142 reader->AddVariable( v0, &var0 );
00143 reader->AddVariable( v1, &var1 );
00144
00145
00146
00147
00148 TString dir = "weights/";
00149 TString prefix = weightFilePrefix;
00150
00151
00152 for (std::map<std::string,int>::iterator it = Use.begin(); it != Use.end(); it++) {
00153 if (it->second) {
00154 TString methodName = it->first + " method";
00155 TString weightfile = dir + prefix + "_" + TString(it->first) + ".weights.xml";
00156 reader->BookMVA( methodName, weightfile );
00157 }
00158 }
00159
00160 TFile *f = new TFile(dataFileName);
00161 TTree *signal = (TTree*)f->Get("TreeS");
00162 TTree *background = (TTree*)f->Get("TreeB");
00163
00164
00165
00166
00167
00168 Float_t svar0;
00169 Float_t svar1;
00170 Float_t bvar0;
00171 Float_t bvar1;
00172
00173
00174 signal->SetBranchAddress(v0,&svar0);
00175 signal->SetBranchAddress(v1,&svar1);
00176 background->SetBranchAddress(v0,&bvar0);
00177 background->SetBranchAddress(v1,&bvar1);
00178
00179
00180 UInt_t nbin = 50;
00181 Float_t xmax = signal->GetMaximum(v0.Data());
00182 Float_t xmin = signal->GetMinimum(v0.Data());
00183 Float_t ymax = signal->GetMaximum(v1.Data());
00184 Float_t ymin = signal->GetMinimum(v1.Data());
00185
00186 TH2D *hs=new TH2D("hs","",nbin,xmin,xmax,nbin,ymin,ymax);
00187 TH2D *hb=new TH2D("hb","",nbin,xmin,xmax,nbin,ymin,ymax);
00188
00189
00190 Long64_t nentries;
00191 nentries = TreeS->GetEntries();
00192 for (Long64_t is=0; is<nentries;is++) {
00193 signal->GetEntry(is);
00194 hs->Fill(svar0,svar1);
00195 }
00196 nentries = TreeB->GetEntries();
00197 for (Long64_t ib=0; ib<nentries;ib++) {
00198 background->GetEntry(ib);
00199 hb->Fill(bvar0,bvar1);
00200 }
00201
00202
00203 hb->SetMarkerColor(4);
00204 hs->SetMarkerColor(2);
00205
00206
00207
00208 TH2F *histLk(0), *histLkD(0), *histLkPCA(0), *histLkKDE(0), *histLkMIX(0), *histPD(0), *histPDD(0);
00209 TH2F *histPDPCA(0), *histPDEFoam(0), *histPDEFoamErr(0), *histPDEFoamSig(0), *histKNN(0), *histHm(0);
00210 TH2F *histFi(0), *histFiG(0), *histFiB(0), *histLD(0), *histNn(0), *histNnC(0), *histNnT(0), *histBdt(0), *histBdtG(0), *histBdtD(0);
00211 TH2F *histRf(0), *histSVMG(0), *histSVMP(0), *histSVML(0), *histFDAMT(0), *histFDAGA(0), *histCat(0), *histPBdt(0);
00212
00213 if (Use["Likelihood"]) histLk = new TH2F( "MVA_Likelihood", "MVA_Likelihood", nbin,xmin,xmax,nbin,ymin,ymax);
00214 if (Use["LikelihoodD"]) histLkD = new TH2F( "MVA_LikelihoodD", "MVA_LikelihoodD", nbin,xmin,xmax,nbin,ymin,ymax);
00215 if (Use["LikelihoodPCA"]) histLkPCA = new TH2F( "MVA_LikelihoodPCA", "MVA_LikelihoodPCA", nbin,xmin,xmax,nbin,ymin,ymax);
00216 if (Use["LikelihoodKDE"]) histLkKDE = new TH2F( "MVA_LikelihoodKDE", "MVA_LikelihoodKDE", nbin,xmin,xmax,nbin,ymin,ymax);
00217 if (Use["LikelihoodMIX"]) histLkMIX = new TH2F( "MVA_LikelihoodMIX", "MVA_LikelihoodMIX", nbin,xmin,xmax,nbin,ymin,ymax);
00218 if (Use["PDERS"]) histPD = new TH2F( "MVA_PDERS", "MVA_PDERS", nbin,xmin,xmax,nbin,ymin,ymax);
00219 if (Use["PDERSD"]) histPDD = new TH2F( "MVA_PDERSD", "MVA_PDERSD", nbin,xmin,xmax,nbin,ymin,ymax);
00220 if (Use["PDERSPCA"]) histPDPCA = new TH2F( "MVA_PDERSPCA", "MVA_PDERSPCA", nbin,xmin,xmax,nbin,ymin,ymax);
00221 if (Use["KNN"]) histKNN = new TH2F( "MVA_KNN", "MVA_KNN", nbin,xmin,xmax,nbin,ymin,ymax);
00222 if (Use["HMatrix"]) histHm = new TH2F( "MVA_HMatrix", "MVA_HMatrix", nbin,xmin,xmax,nbin,ymin,ymax);
00223 if (Use["Fisher"]) histFi = new TH2F( "MVA_Fisher", "MVA_Fisher", nbin,xmin,xmax,nbin,ymin,ymax);
00224 if (Use["FisherG"]) histFiG = new TH2F( "MVA_FisherG", "MVA_FisherG", nbin,xmin,xmax,nbin,ymin,ymax);
00225 if (Use["BoostedFisher"]) histFiB = new TH2F( "MVA_BoostedFisher", "MVA_BoostedFisher", nbin,xmin,xmax,nbin,ymin,ymax);
00226 if (Use["LD"]) histLD = new TH2F( "MVA_LD", "MVA_LD", nbin,xmin,xmax,nbin,ymin,ymax);
00227 if (Use["MLP"]) histNn = new TH2F( "MVA_MLP", "MVA_MLP", nbin,xmin,xmax,nbin,ymin,ymax);
00228 if (Use["CFMlpANN"]) histNnC = new TH2F( "MVA_CFMlpANN", "MVA_CFMlpANN", nbin,xmin,xmax,nbin,ymin,ymax);
00229 if (Use["TMlpANN"]) histNnT = new TH2F( "MVA_TMlpANN", "MVA_TMlpANN", nbin,xmin,xmax,nbin,ymin,ymax);
00230 if (Use["BDT"]) histBdt = new TH2F( "MVA_BDT", "MVA_BDT", nbin,xmin,xmax,nbin,ymin,ymax);
00231 if (Use["BDTD"]) histBdtD = new TH2F( "MVA_BDTD", "MVA_BDTD", nbin,xmin,xmax,nbin,ymin,ymax);
00232 if (Use["BDTG"]) histBdtG = new TH2F( "MVA_BDTG", "MVA_BDTG", nbin,xmin,xmax,nbin,ymin,ymax);
00233 if (Use["RuleFit"]) histRf = new TH2F( "MVA_RuleFit", "MVA_RuleFit", nbin,xmin,xmax,nbin,ymin,ymax);
00234 if (Use["SVM_Gauss"]) histSVMG = new TH2F( "MVA_SVM_Gauss", "MVA_SVM_Gauss", nbin,xmin,xmax,nbin,ymin,ymax);
00235 if (Use["SVM_Poly"]) histSVMP = new TH2F( "MVA_SVM_Poly", "MVA_SVM_Poly", nbin,xmin,xmax,nbin,ymin,ymax);
00236 if (Use["SVM_Lin"]) histSVML = new TH2F( "MVA_SVM_Lin", "MVA_SVM_Lin", nbin,xmin,xmax,nbin,ymin,ymax);
00237 if (Use["FDA_MT"]) histFDAMT = new TH2F( "MVA_FDA_MT", "MVA_FDA_MT", nbin,xmin,xmax,nbin,ymin,ymax);
00238 if (Use["FDA_GA"]) histFDAGA = new TH2F( "MVA_FDA_GA", "MVA_FDA_GA", nbin,xmin,xmax,nbin,ymin,ymax);
00239 if (Use["Category"]) histCat = new TH2F( "MVA_Category", "MVA_Category", nbin,xmin,xmax,nbin,ymin,ymax);
00240 if (Use["Plugin"]) histPBdt = new TH2F( "MVA_PBDT", "MVA_BDT", nbin,xmin,xmax,nbin,ymin,ymax);
00241
00242
00243
00244
00245
00246
00247
00248
00249
00250 for (Int_t ibin=1; ibin<nbin+1; ibin++){
00251 for (Int_t jbin=1; jbin<nbin+1; jbin++){
00252 var0 = hs->GetXaxis()->GetBinCenter(ibin);
00253 var1 = hs->GetYaxis()->GetBinCenter(jbin);
00254
00255
00256 if (Use["Likelihood" ]) histLk ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "Likelihood method" ) );
00257 if (Use["LikelihoodD" ]) histLkD ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "LikelihoodD method" ) );
00258 if (Use["LikelihoodPCA"]) histLkPCA ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "LikelihoodPCA method" ) );
00259 if (Use["LikelihoodKDE"]) histLkKDE ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "LikelihoodKDE method" ) );
00260 if (Use["LikelihoodMIX"]) histLkMIX ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "LikelihoodMIX method" ) );
00261 if (Use["PDERS" ]) histPD ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "PDERS method" ) );
00262 if (Use["PDERSD" ]) histPDD ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "PDERSD method" ) );
00263 if (Use["PDERSPCA" ]) histPDPCA ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "PDERSPCA method" ) );
00264 if (Use["KNN" ]) histKNN ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "KNN method" ) );
00265 if (Use["HMatrix" ]) histHm ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "HMatrix method" ) );
00266 if (Use["Fisher" ]) histFi ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "Fisher method" ) );
00267 if (Use["FisherG" ]) histFiG ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "FisherG method" ) );
00268 if (Use["BoostedFisher"]) histFiB ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "BoostedFisher method" ) );
00269 if (Use["LD" ]) histLD ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "LD method" ) );
00270 if (Use["MLP" ]) histNn ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "MLP method" ) );
00271 if (Use["CFMlpANN" ]) histNnC ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "CFMlpANN method" ) );
00272 if (Use["TMlpANN" ]) histNnT ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "TMlpANN method" ) );
00273 if (Use["BDT" ]) histBdt ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "BDT method" ) );
00274 if (Use["BDTD" ]) histBdtD ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "BDTD method" ) );
00275 if (Use["BDTG" ]) histBdtG ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "BDTG method" ) );
00276 if (Use["RuleFit" ]) histRf ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "RuleFit method" ) );
00277 if (Use["SVM_Gauss" ]) histSVMG ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "SVM_Gauss method" ) );
00278 if (Use["SVM_Poly" ]) histSVMP ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "SVM_Poly method" ) );
00279 if (Use["SVM_Lin" ]) histSVML ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "SVM_Lin method" ) );
00280 if (Use["FDA_MT" ]) histFDAMT ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "FDA_MT method" ) );
00281 if (Use["FDA_GA" ]) histFDAGA ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "FDA_GA method" ) );
00282 if (Use["Category" ]) histCat ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "Category method" ) );
00283 if (Use["Plugin" ]) histPBdt ->SetBinContent(ibin,jbin, reader->EvaluateMVA( "P_BDT method" ) );
00284 }
00285 }
00286
00287
00288 std::cout << "--- Created root file: \"TMVApp.root\" containing the MVA output histograms" << std::endl;
00289
00290 delete reader;
00291
00292
00293
00294
00295 std::cout << "==> TMVAClassificationApplication is done!" << endl << std::endl;
00296
00297
00298
00299 gStyle->SetPalette(1);
00300
00301 if (Use["Likelihood" ]) plot(hs,hb,histLk ,v0,v1);
00302 if (Use["LikelihoodD" ]) plot(hs,hb,histLkD ,v0,v1);
00303 if (Use["LikelihoodPCA"]) plot(hs,hb,histLkPCA ,v0,v1);
00304 if (Use["LikelihoodKDE"]) plot(hs,hb,histLkKDE ,v0,v1);
00305 if (Use["LikelihoodMIX"]) plot(hs,hb,histLkMIX ,v0,v1);
00306 if (Use["PDERS" ]) plot(hs,hb,histPD ,v0,v1);
00307 if (Use["PDERSD" ]) plot(hs,hb,histPDD ,v0,v1);
00308 if (Use["PDERSPCA" ]) plot(hs,hb,histPDPCA ,v0,v1);
00309 if (Use["KNN" ]) plot(hs,hb,histKNN ,v0,v1);
00310 if (Use["HMatrix" ]) plot(hs,hb,histHm ,v0,v1);
00311 if (Use["Fisher" ]) plot(hs,hb,histFi ,v0,v1);
00312 if (Use["FisherG" ]) plot(hs,hb,histFiG ,v0,v1);
00313 if (Use["BoostedFisher"]) plot(hs,hb,histFiB ,v0,v1);
00314 if (Use["LD" ]) plot(hs,hb,histLD ,v0,v1);
00315 if (Use["MLP" ]) plot(hs,hb,histNn ,v0,v1);
00316 if (Use["CFMlpANN" ]) plot(hs,hb,histNnC ,v0,v1);
00317 if (Use["TMlpANN" ]) plot(hs,hb,histNnT ,v0,v1);
00318 if (Use["BDT" ]) plot(hs,hb,histBdt ,v0,v1);
00319 if (Use["BDTD" ]) plot(hs,hb,histBdtD ,v0,v1);
00320 if (Use["BDTG" ]) plot(hs,hb,histBdtG ,v0,v1);
00321 if (Use["RuleFit" ]) plot(hs,hb,histRf ,v0,v1);
00322 if (Use["SVM_Gauss" ]) plot(hs,hb,histSVMG ,v0,v1);
00323 if (Use["SVM_Poly" ]) plot(hs,hb,histSVMP ,v0,v1);
00324 if (Use["SVM_Lin" ]) plot(hs,hb,histSVML ,v0,v1);
00325 if (Use["FDA_MT" ]) plot(hs,hb,histFDAMT ,v0,v1);
00326 if (Use["FDA_GA" ]) plot(hs,hb,histFDAGA ,v0,v1);
00327 if (Use["Category" ]) plot(hs,hb,histCat ,v0,v1);
00328 if (Use["Plugin" ]) plot(hs,hb,histPBdt ,v0,v1);
00329
00330
00331
00332
00333
00334 TFile *target = new TFile( "TMVApp.root","RECREATE" );
00335
00336 hs->Write();
00337 hb->Write();
00338
00339 if (Use["Likelihood" ]) histLk ->Write();
00340 if (Use["LikelihoodD" ]) histLkD ->Write();
00341 if (Use["LikelihoodPCA"]) histLkPCA ->Write();
00342 if (Use["LikelihoodKDE"]) histLkKDE ->Write();
00343 if (Use["LikelihoodMIX"]) histLkMIX ->Write();
00344 if (Use["PDERS" ]) histPD ->Write();
00345 if (Use["PDERSD" ]) histPDD ->Write();
00346 if (Use["PDERSPCA" ]) histPDPCA ->Write();
00347 if (Use["KNN" ]) histKNN ->Write();
00348 if (Use["HMatrix" ]) histHm ->Write();
00349 if (Use["Fisher" ]) histFi ->Write();
00350 if (Use["FisherG" ]) histFiG ->Write();
00351 if (Use["BoostedFisher"]) histFiB ->Write();
00352 if (Use["LD" ]) histLD ->Write();
00353 if (Use["MLP" ]) histNn ->Write();
00354 if (Use["CFMlpANN" ]) histNnC ->Write();
00355 if (Use["TMlpANN" ]) histNnT ->Write();
00356 if (Use["BDT" ]) histBdt ->Write();
00357 if (Use["BDTD" ]) histBdtD ->Write();
00358 if (Use["BDTG" ]) histBdtG ->Write();
00359 if (Use["RuleFit" ]) histRf ->Write();
00360 if (Use["SVM_Gauss" ]) histSVMG ->Write();
00361 if (Use["SVM_Poly" ]) histSVMP ->Write();
00362 if (Use["SVM_Lin" ]) histSVML ->Write();
00363 if (Use["FDA_MT" ]) histFDAMT ->Write();
00364 if (Use["FDA_GA" ]) histFDAGA ->Write();
00365 if (Use["Category" ]) histCat ->Write();
00366 if (Use["Plugin" ]) histPBdt ->Write();
00367
00368 target->Close();
00369
00370 }
00371