#include <cstdio>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <string>
00004
00005 #include "tmvaglob.C"
00006
00007 #include "RQ_OBJECT.h"
00008
00009 #include "TROOT.h"
00010 #include "TStyle.h"
00011 #include "TPad.h"
00012 #include "TCanvas.h"
00013 #include "TLine.h"
00014 #include "TFile.h"
00015 #include "TColor.h"
00016 #include "TPaveText.h"
00017 #include "TObjString.h"
00018 #include "TControlBar.h"
00019
00020 #include "TGWindow.h"
00021 #include "TGButton.h"
00022 #include "TGLabel.h"
00023 #include "TGNumberEntry.h"
00024
00025 #include "TMVA/DecisionTree.h"
00026 #include "TMVA/Tools.h"
00027 #include "TXMLEngine.h"
00028
00029
00030
00031
00032
00033
00034
00035
00036 static const Int_t kSigColorF = TColor::GetColor( "#2244a5" );
00037 static const Int_t kBkgColorF = TColor::GetColor( "#dd0033" );
00038 static const Int_t kIntColorF = TColor::GetColor( "#33aa77" );
00039
00040 static const Int_t kSigColorT = 10;
00041 static const Int_t kBkgColorT = 10;
00042 static const Int_t kIntColorT = 10;
00043
00044 enum PlotType { EffPurity = 0 };
00045
// ROOT GUI dialog that lets the user choose a decision-tree index from a BDT
// weight file and draw that tree on a canvas. The most recently constructed
// instance is tracked through the static fThis pointer (see Delete()).
class StatDialogBDT {

   RQ_OBJECT("StatDialogBDT")

public:

   // p:        parent window of the dialog
   // wfile:    BDT weight file (plain text, or XML when it ends in ".xml")
   // methName: method name, used when naming the saved image
   // itree:    initially selected tree index
   StatDialogBDT( const TGWindow* p, TString wfile = "weights/TMVAClassification_BDT.weights.txt",
                  TString methName = "BDT", Int_t itree = 0 );
   // Resets DecisionTreeNode::fgIsTraining (set to true by the constructor),
   // clears fThis, closes the main frame and deletes the canvas only if it is
   // still present in gROOT's list of canvases.
   virtual ~StatDialogBDT() {
      TMVA::DecisionTreeNode::fgIsTraining=false;
      fThis = 0;
      fMain->CloseWindow();
      fMain->Cleanup();
      if(gROOT->GetListOfCanvases()->FindObject(fCanvas))
         delete fCanvas;
   }


   // Reads tree number itree from the weight file and draws it on fCanvas.
   void DrawTree( Int_t itree );

   // Brings the dialog window to the front; does nothing if fMain is null.
   void RaiseDialog() { if (fMain) { fMain->RaiseWindow(); fMain->Layout(); fMain->MapWindow(); } }

private:

   TGMainFrame *fMain;           // top-level dialog window
   Int_t fItree;                 // currently selected tree index
   Int_t fNtrees;                // number of trees found in the weight file
   TCanvas* fCanvas;             // drawing canvas, created on first DrawTree()

   TGNumberEntry* fInput;        // entry field for the tree index

   TGHorizontalFrame* fButtons;  // frame holding the Close and Draw buttons
   TGTextButton* fDrawButton;
   TGTextButton* fCloseButton;

   // Redraws the currently selected tree.
   void UpdateCanvases();


   // Reads tree itree from fWfile; on success also hands back, via vars, a
   // newly allocated array of variable names.
   TMVA::DecisionTree* ReadTree( TString * &vars, Int_t itree );
   // Recursively draws node n and its daughters around (x, y) in NDC.
   void DrawNode( TMVA::DecisionTreeNode *n,
                  Double_t x, Double_t y, Double_t xscale, Double_t yscale, TString* vars );
   // Fills fNtrees from the weight file.
   void GetNtrees();

   TString fWfile;               // weight file path
   TString fMethName;            // method name (used in image file names)

public:


   // Deletes the tracked dialog instance, if any, and clears fThis.
   static void Delete() { if (fThis != 0) { delete fThis; fThis = 0; } }


   // Slots connected to GUI signals in the constructor.
   void SetItree();
   void Redraw();
   void Close();

private:

   static StatDialogBDT* fThis;  // most recently constructed instance

};
00107
00108 StatDialogBDT* StatDialogBDT::fThis = 0;
00109
00110 void StatDialogBDT::SetItree()
00111 {
00112 fItree = Int_t(fInput->GetNumber());
00113 }
00114
00115 void StatDialogBDT::Redraw()
00116 {
00117 UpdateCanvases();
00118 }
00119
00120 void StatDialogBDT::Close()
00121 {
00122 delete this;
00123 }
00124
// Builds the dialog: reads the number of trees from the weight file, then
// creates a label, a tree-index entry and Close/Draw buttons, and connects
// their signals to this object's slots.
StatDialogBDT::StatDialogBDT( const TGWindow* p, TString wfile, TString methName, Int_t itree )
   : fMain( 0 ),
     fItree(itree),
     fNtrees(0),
     fCanvas(0),
     fInput(0),
     fButtons(0),
     fDrawButton(0),
     fCloseButton(0),
     fWfile( wfile ),
     fMethName( methName )
{
   UInt_t totalWidth = 500;
   UInt_t totalHeight = 200;

   // Track this instance so that the static Delete() can remove it later.
   fThis = this;

   // Set while the dialog exists; the destructor resets it to false.
   // NOTE(review): presumably required by the DecisionTree reading code -- confirm.
   TMVA::DecisionTreeNode::fgIsTraining=true;


   // fNtrees must be known before the entry limits and label text are set.
   GetNtrees();


   fMain = new TGMainFrame(p, totalWidth, totalHeight, kMainFrame | kVerticalFrame);

   TGLabel *sigLab = new TGLabel( fMain, Form( "Decision tree [%i-%i]",0,fNtrees-1 ) );
   fMain->AddFrame(sigLab, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));

   // NOTE(review): style 5 is a raw TGNumberFormat::EStyle value
   // (presumably kNESReal) -- confirm. The range is limited to [0, fNtrees-1].
   fInput = new TGNumberEntry(fMain, (Double_t) fItree,5,-1,(TGNumberFormat::EStyle) 5);
   fMain->AddFrame(fInput, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
   fInput->Resize(100,24);
   fInput->SetLimits(TGNumberFormat::kNELLimitMinMax,0,fNtrees-1);

   fButtons = new TGHorizontalFrame(fMain, totalWidth,30);

   fCloseButton = new TGTextButton(fButtons,"&Close");
   fButtons->AddFrame(fCloseButton, new TGLayoutHints(kLHintsLeft | kLHintsTop));

   fDrawButton = new TGTextButton(fButtons,"&Draw");
   fButtons->AddFrame(fDrawButton, new TGLayoutHints(kLHintsRight | kLHintsTop,15));

   fMain->AddFrame(fButtons,new TGLayoutHints(kLHintsLeft | kLHintsBottom,5,5,5,5));

   fMain->SetWindowName("Decision tree");
   fMain->SetWMPosition(0,0);
   fMain->MapSubwindows();
   fMain->Resize(fMain->GetDefaultSize());
   fMain->MapWindow();

   // Entering a value in the number entry updates fItree.
   fInput->Connect("ValueSet(Long_t)","StatDialogBDT",this, "SetItree()");

   // Clicking Draw first re-emits ValueSet (so fItree is refreshed via
   // SetItree), then triggers Redraw(); the two are connected in this order.
   fDrawButton->Connect("Clicked()","TGNumberEntry",fInput, "ValueSet(Long_t)");
   fDrawButton->Connect("Clicked()", "StatDialogBDT", this, "Redraw()");

   fCloseButton->Connect("Clicked()", "StatDialogBDT", this, "Close()");
}
00181
00182 void StatDialogBDT::UpdateCanvases()
00183 {
00184 DrawTree( fItree );
00185 }
00186
00187 void StatDialogBDT::GetNtrees()
00188 {
00189 if(!fWfile.EndsWith(".xml") ){
00190 ifstream fin( fWfile );
00191 if (!fin.good( )) {
00192 cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
00193 return;
00194 }
00195
00196 TString dummy = "";
00197
00198
00199 Int_t nc = 0;
00200 while (!dummy.Contains("NTrees")) {
00201 fin >> dummy;
00202 nc++;
00203 if (nc > 200) {
00204 cout << endl;
00205 cout << "*** Huge problem: could not locate term \"NTrees\" in BDT weight file: "
00206 << fWfile << endl;
00207 cout << "==> panic abort (please contact the TMVA authors)" << endl;
00208 cout << endl;
00209 exit(1);
00210 }
00211 }
00212 fin >> dummy;
00213 fNtrees = dummy.ReplaceAll("\"","").Atoi();
00214 fin.close();
00215 }
00216 else{
00217 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
00218 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
00219 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
00220 while(ch){
00221 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
00222 if(nodeName=="Weights") {
00223 TMVA::gTools().ReadAttr( ch, "NTrees", fNtrees );
00224 break;
00225 }
00226 ch = TMVA::gTools().xmlengine().GetNext(ch);
00227 }
00228 }
00229 cout << "--- Found " << fNtrees << " decision trees in weight file" << endl;
00230
00231 }
00232
00233
00234 void StatDialogBDT::DrawNode( TMVA::DecisionTreeNode *n,
00235 Double_t x, Double_t y,
00236 Double_t xscale, Double_t yscale, TString * vars)
00237 {
00238
00239
00240 Float_t xsize=xscale*1.5;
00241 Float_t ysize=yscale/3;
00242 if (xsize>0.15) xsize=0.1;
00243 if (n->GetLeft() != NULL){
00244 TLine *a1 = new TLine(x-xscale/4,y-ysize,x-xscale,y-ysize*2);
00245 a1->SetLineWidth(2);
00246 a1->Draw();
00247 DrawNode((TMVA::DecisionTreeNode*) n->GetLeft(), x-xscale, y-yscale, xscale/2, yscale, vars);
00248 }
00249 if (n->GetRight() != NULL){
00250 TLine *a1 = new TLine(x+xscale/4,y-ysize,x+xscale,y-ysize*2);
00251 a1->SetLineWidth(2);
00252 a1->Draw();
00253 DrawNode((TMVA::DecisionTreeNode*) n->GetRight(), x+xscale, y-yscale, xscale/2, yscale, vars );
00254 }
00255
00256
00257 TPaveText *t = new TPaveText(x-xsize,y-ysize,x+xsize,y+ysize, "NDC");
00258
00259 t->SetBorderSize(1);
00260
00261 t->SetFillStyle(1);
00262 if (n->GetNodeType() == 1) { t->SetFillColor( kSigColorF ); t->SetTextColor( kSigColorT ); }
00263 else if (n->GetNodeType() == -1) { t->SetFillColor( kBkgColorF ); t->SetTextColor( kBkgColorT ); }
00264 else if (n->GetNodeType() == 0) { t->SetFillColor( kIntColorF ); t->SetTextColor( kIntColorT ); }
00265
00266 char buffer[25];
00267 sprintf( buffer, "N=%f", n->GetNEvents() );
00268 if (n->GetNEvents()>0) t->AddText(buffer);
00269 sprintf( buffer, "S/(S+B)=%4.3f", n->GetPurity() );
00270 t->AddText(buffer);
00271
00272 if (n->GetNodeType() == 0){
00273 if (n->GetCutType()){
00274 t->AddText(TString(vars[n->GetSelector()]+">"+=::Form("%5.3g",n->GetCutValue())));
00275 }else{
00276 t->AddText(TString(vars[n->GetSelector()]+"<"+=::Form("%5.3g",n->GetCutValue())));
00277 }
00278 }
00279
00280 t->Draw();
00281
00282 return;
00283 }
00284 TMVA::DecisionTree* StatDialogBDT::ReadTree( TString* &vars, Int_t itree )
00285 {
00286 cout << "--- Reading Tree " << itree << " from weight file: " << fWfile << endl;
00287 TMVA::DecisionTree *d = new TMVA::DecisionTree();
00288 if(!fWfile.EndsWith(".xml") ){
00289 ifstream fin( fWfile );
00290 if (!fin.good( )) {
00291 cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
00292 return 0;
00293 }
00294
00295 TString dummy = "";
00296
00297 if (itree >= fNtrees) {
00298 cout << "*** ERROR: requested decision tree: " << itree
00299 << ", but number of trained trees only: " << fNtrees << endl;
00300 return 0;
00301 }
00302
00303
00304 while (!dummy.Contains("#VAR")) fin >> dummy;
00305 fin >> dummy >> dummy >> dummy;
00306
00307
00308 Int_t nVars;
00309 fin >> dummy >> nVars;
00310
00311
00312 vars = new TString[nVars+1];
00313 for (Int_t i = 0; i < nVars; i++) fin >> vars[i] >> dummy >> dummy >> dummy >> dummy;
00314 vars[nVars]="FisherCrit";
00315
00316 char buffer[20];
00317 char line[256];
00318 sprintf(buffer,"Tree %d",itree);
00319
00320 while (!dummy.Contains(buffer)) {
00321 fin.getline(line,256);
00322 dummy = TString(line);
00323 }
00324
00325 d->Read(fin);
00326
00327 fin.close();
00328 }
00329 else{
00330 if (itree >= fNtrees) {
00331 cout << "*** ERROR: requested decision tree: " << itree
00332 << ", but number of trained trees only: " << fNtrees << endl;
00333 return 0;
00334 }
00335 Int_t nVars;
00336 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
00337 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
00338 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
00339 while(ch){
00340 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
00341 if(nodeName=="Variables"){
00342 TMVA::gTools().ReadAttr( ch, "NVar", nVars);
00343 vars = new TString[nVars+1];
00344 void* varnode = TMVA::gTools().xmlengine().GetChild(ch);
00345 for (Int_t i = 0; i < nVars; i++){
00346 TMVA::gTools().ReadAttr( varnode, "Expression", vars[i]);
00347 varnode = TMVA::gTools().xmlengine().GetNext(varnode);
00348 }
00349 vars[nVars]="FisherCrit";
00350 }
00351 if(nodeName=="Weights") break;
00352 ch = TMVA::gTools().xmlengine().GetNext(ch);
00353 }
00354 ch = TMVA::gTools().xmlengine().GetChild(ch);
00355 for (int i=0; i<itree; i++) ch = TMVA::gTools().xmlengine().GetNext(ch);
00356 d->ReadXML(ch);
00357 }
00358 return d;
00359 }
00360
00361
00362 void StatDialogBDT::DrawTree( Int_t itree )
00363 {
00364 TString *vars;
00365 TMVA::DecisionTree* d = ReadTree( vars, itree );
00366 if (d == 0) return;
00367
00368 UInt_t depth = d->GetTotalTreeDepth();
00369 Double_t ystep = 1.0/(depth + 1.0);
00370
00371 cout << "--- Tree depth: " << depth << endl;
00372
00373 TStyle* TMVAStyle = gROOT->GetStyle("Plain");
00374 Int_t canvasColor = TMVAStyle->GetCanvasColor();
00375
00376 TString cbuffer = Form( "Reading weight file: %s", fWfile.Data() );
00377 TString tbuffer = Form( "Decision Tree no.: %d", itree );
00378 if (!fCanvas) fCanvas = new TCanvas( "c1", cbuffer, 200, 0, 1000, 600 );
00379 else fCanvas->Clear();
00380 fCanvas->Draw();
00381
00382 DrawNode( (TMVA::DecisionTreeNode*)d->GetRoot(), 0.5, 1.-0.5*ystep, 0.25, ystep ,vars);
00383
00384
00385 Double_t yup=0.99;
00386 Double_t ydown=yup-ystep/2.5;
00387 Double_t dy= ystep/2.5 * 0.2;
00388
00389 TPaveText *whichTree = new TPaveText(0.85,ydown,0.98,yup, "NDC");
00390 whichTree->SetBorderSize(1);
00391 whichTree->SetFillStyle(1);
00392 whichTree->SetFillColor( TColor::GetColor( "#ffff33" ) );
00393 whichTree->AddText( tbuffer );
00394 whichTree->Draw();
00395
00396 TPaveText *intermediate = new TPaveText(0.02,ydown,0.15,yup, "NDC");
00397 intermediate->SetBorderSize(1);
00398 intermediate->SetFillStyle(1);
00399 intermediate->SetFillColor( kIntColorF );
00400 intermediate->AddText("Intermediate Nodes");
00401 intermediate->SetTextColor( kIntColorT );
00402 intermediate->Draw();
00403
00404 ydown = ydown - ystep/2.5 -dy;
00405 yup = yup - ystep/2.5 -dy;
00406 TPaveText *signalleaf = new TPaveText(0.02,ydown ,0.15,yup, "NDC");
00407 signalleaf->SetBorderSize(1);
00408 signalleaf->SetFillStyle(1);
00409 signalleaf->SetFillColor( kSigColorF );
00410 signalleaf->AddText("Signal Leaf Nodes");
00411 signalleaf->SetTextColor( kSigColorT );
00412 signalleaf->Draw();
00413
00414 ydown = ydown - ystep/2.5 -dy;
00415 yup = yup - ystep/2.5 -dy;
00416 TPaveText *backgroundleaf = new TPaveText(0.02,ydown,0.15,yup, "NDC");
00417 backgroundleaf->SetBorderSize(1);
00418 backgroundleaf->SetFillStyle(1);
00419 backgroundleaf->SetFillColor( kBkgColorF );
00420
00421 backgroundleaf->AddText("Backgr. Leaf Nodes");
00422 backgroundleaf->SetTextColor( kBkgColorT );
00423 backgroundleaf->Draw();
00424
00425 fCanvas->Update();
00426 TString fname = Form("plots/%s_%i", fMethName.Data(), itree );
00427 cout << "--- Creating image: " << fname << endl;
00428 TMVAGlob::imgconv( fCanvas, fname );
00429
00430 TMVAStyle->SetCanvasColor( canvasColor );
00431 }
00432
00433
00434
00435 static std::vector<TControlBar*> BDT_Global__cbar;
00436
00437
00438 void BDT( const TString& fin = "TMVA.root" )
00439 {
00440
00441
00442
00443 TMVAGlob::DestroyCanvases();
00444
00445
00446 TFile* file = TMVAGlob::OpenFile( fin );
00447
00448 TDirectory* dir = file->GetDirectory( "Method_BDT" );
00449 if (!dir) {
00450 cout << "*** Error in macro \"BDT.C\": cannot find directory \"Method_BDT\" in file: " << fin << endl;
00451 return;
00452 }
00453
00454
00455 TIter next( dir->GetListOfKeys() );
00456 TKey *key(0);
00457 std::vector<TString> methname;
00458 std::vector<TString> path;
00459 std::vector<TString> wfile;
00460 while ((key = (TKey*)next())) {
00461 TDirectory* mdir = dir->GetDirectory( key->GetName() );
00462 if (!mdir) {
00463 cout << "*** Error in macro \"BDT.C\": cannot find sub-directory: " << key->GetName()
00464 << " in directory: " << dir->GetName() << endl;
00465 return;
00466 }
00467
00468
00469 TObjString* strPath = (TObjString*)mdir->Get( "TrainingPath" );
00470 TObjString* strWFile = (TObjString*)mdir->Get( "WeightFileName" );
00471 if (!strPath || !strWFile) {
00472 cout << "*** Error in macro \"BDT.C\": could not find TObjStrings \"TrainingPath\" and/or \"WeightFileName\" *** " << endl;
00473 cout << "*** Maybe you are using TMVA >= 3.8.15 with an older training target file ? *** " << endl;
00474 return;
00475 }
00476
00477 methname.push_back( key->GetName() );
00478 path .push_back( strPath->GetString() );
00479 wfile .push_back( strWFile->GetString() );
00480 }
00481
00482
00483 TControlBar* cbar = new TControlBar( "vertical", "Choose weight file:", 50, 50 );
00484 BDT_Global__cbar.push_back(cbar);
00485
00486 for (UInt_t im=0; im<path.size(); im++) {
00487 TString fname = path[im];
00488 if (fname[fname.Length()-1] != '/') fname += "/";
00489 fname += wfile[im];
00490 TString macro = Form( ".x BDT.C+\(0,\"%s\",\"%s\")", fname.Data(), methname[im].Data() );
00491 cbar->AddButton( fname, macro, "Plot decision trees from this weight file", "button" );
00492 }
00493
00494
00495 #if ROOT_VERSION_CODE < ROOT_VERSION(5,19,0)
00496 cbar->AddButton( "Close", Form("BDT_DeleteTBar(%i)", BDT_Global__cbar.size()-1), "Close this control bar", "button" );
00497 #endif
00498
00499
00500
00501 cbar->SetTextColor("blue");
00502
00503
00504 cbar->Show();
00505 }
00506
00507 void BDT_DeleteTBar(int i)
00508 {
00509
00510 StatDialogBDT::Delete();
00511 TMVAGlob::DestroyCanvases();
00512
00513 delete BDT_Global__cbar[i];
00514 BDT_Global__cbar[i] = 0;
00515 }
00516
00517
00518
00519 void BDT( Int_t itree, TString wfile = "weights/TMVAnalysis_test_BDT.weights.txt", TString methName = "BDT", Bool_t useTMVAStyle = kTRUE )
00520 {
00521
00522 StatDialogBDT::Delete();
00523 TMVAGlob::DestroyCanvases();
00524
00525
00526 if(!wfile.EndsWith(".xml") ){
00527 ifstream fin( wfile );
00528 if (!fin.good( )) {
00529 cout << "*** ERROR: Weight file: " << wfile << " does not exist" << endl;
00530 return;
00531 }
00532 }
00533 std::cout << "test1";
00534
00535 TMVAGlob::Initialize( useTMVAStyle );
00536
00537 StatDialogBDT* gGui = new StatDialogBDT( gClient->GetRoot(), wfile, methName, itree );
00538
00539 gGui->DrawTree( itree );
00540
00541 gGui->RaiseDialog();
00542 }
00543