00001 #include <iostream>
00002 #include <iomanip>
00003 #include <fstream>
00004
00005 #include "tmvaglob.C"
00006
00007 #include "RQ_OBJECT.h"
00008
00009 #include "TROOT.h"
00010 #include "TStyle.h"
00011 #include "TPad.h"
00012 #include "TCanvas.h"
00013 #include "TLine.h"
00014 #include "TFile.h"
00015 #include "TColor.h"
00016 #include "TPaveText.h"
00017 #include "TObjString.h"
00018 #include "TControlBar.h"
00019
00020 #include "TGWindow.h"
00021 #include "TGButton.h"
00022 #include "TGLabel.h"
00023 #include "TGNumberEntry.h"
00024
00025 #include "TMVA/DecisionTree.h"
00026 #include "TMVA/Tools.h"
00027 #include "TXMLEngine.h"
00028
00029
00030
00031
00032
00033
00034
00035
00036 static const Int_t kSigColorF = TColor::GetColor( "#2244a5" );
00037 static const Int_t kBkgColorF = TColor::GetColor( "#dd0033" );
00038 static const Int_t kIntColorF = TColor::GetColor( "#33aa77" );
00039
00040 static const Int_t kSigColorT = 10;
00041 static const Int_t kBkgColorT = 10;
00042 static const Int_t kIntColorT = 10;
00043
00044 enum PlotType { EffPurity = 0 };
00045
00046 class StatDialogBDT {
00047
00048 RQ_OBJECT("StatDialogBDT")
00049
00050 public:
00051
00052 StatDialogBDT( const TGWindow* p, TString wfile = "weights/TMVARegression_BDT.weights.xml",
00053 TString methName = "BDT", Int_t itree = 0 );
00054 virtual ~StatDialogBDT() {
00055 TMVA::DecisionTreeNode::fgIsTraining=false;
00056 fThis = 0;
00057 fMain->CloseWindow();
00058 fMain->Cleanup();
00059 if(gROOT->GetListOfCanvases()->FindObject(fCanvas))
00060 delete fCanvas;
00061 }
00062
00063
00064 void DrawTree( Int_t itree );
00065
00066 void RaiseDialog() { if (fMain) { fMain->RaiseWindow(); fMain->Layout(); fMain->MapWindow(); } }
00067
00068 private:
00069
00070 TGMainFrame *fMain;
00071 Int_t fItree;
00072 Int_t fNtrees;
00073 TCanvas* fCanvas;
00074
00075 TGNumberEntry* fInput;
00076
00077 TGHorizontalFrame* fButtons;
00078 TGTextButton* fDrawButton;
00079 TGTextButton* fCloseButton;
00080
00081 void UpdateCanvases();
00082
00083
00084 TMVA::DecisionTree* ReadTree( TString * &vars, Int_t itree );
00085 void DrawNode( TMVA::DecisionTreeNode *n,
00086 Double_t x, Double_t y, Double_t xscale, Double_t yscale, TString* vars );
00087 void GetNtrees();
00088
00089 TString fWfile;
00090 TString fMethName;
00091
00092 public:
00093
00094
00095 static void Delete() { if (fThis != 0) { delete fThis; fThis = 0; } }
00096
00097
00098 void SetItree();
00099 void Redraw();
00100 void Close();
00101
00102 private:
00103
00104 static StatDialogBDT* fThis;
00105
00106 };
00107
00108 StatDialogBDT* StatDialogBDT::fThis = 0;
00109
00110 void StatDialogBDT::SetItree()
00111 {
00112 fItree = Int_t(fInput->GetNumber());
00113 }
00114
00115 void StatDialogBDT::Redraw()
00116 {
00117 UpdateCanvases();
00118 }
00119
00120 void StatDialogBDT::Close()
00121 {
00122 delete this;
00123 }
00124
00125 StatDialogBDT::StatDialogBDT( const TGWindow* p, TString wfile, TString methName, Int_t itree )
00126 : fMain( 0 ),
00127 fItree(itree),
00128 fNtrees(0),
00129 fCanvas(0),
00130 fInput(0),
00131 fButtons(0),
00132 fDrawButton(0),
00133 fCloseButton(0),
00134 fWfile( wfile ),
00135 fMethName( methName )
00136 {
00137 UInt_t totalWidth = 500;
00138 UInt_t totalHeight = 200;
00139
00140 fThis = this;
00141
00142
00143 GetNtrees();
00144
00145
00146 fMain = new TGMainFrame(p, totalWidth, totalHeight, kMainFrame | kVerticalFrame);
00147
00148 TGLabel *sigLab = new TGLabel( fMain, Form( "Regression tree [%i-%i]",0,fNtrees-1 ) );
00149 fMain->AddFrame(sigLab, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
00150
00151 fInput = new TGNumberEntry(fMain, (Double_t) fItree,5,-1,(TGNumberFormat::EStyle) 5);
00152 fMain->AddFrame(fInput, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
00153 fInput->Resize(100,24);
00154 fInput->SetLimits(TGNumberFormat::kNELLimitMinMax,0,fNtrees-1);
00155
00156 fButtons = new TGHorizontalFrame(fMain, totalWidth,30);
00157
00158 fCloseButton = new TGTextButton(fButtons,"&Close");
00159 fButtons->AddFrame(fCloseButton, new TGLayoutHints(kLHintsLeft | kLHintsTop));
00160
00161 fDrawButton = new TGTextButton(fButtons,"&Draw");
00162 fButtons->AddFrame(fDrawButton, new TGLayoutHints(kLHintsRight | kLHintsTop,15));
00163
00164 fMain->AddFrame(fButtons,new TGLayoutHints(kLHintsLeft | kLHintsBottom,5,5,5,5));
00165
00166 fMain->SetWindowName("Regression tree");
00167 fMain->SetWMPosition(0,0);
00168 fMain->MapSubwindows();
00169 fMain->Resize(fMain->GetDefaultSize());
00170 fMain->MapWindow();
00171
00172 fInput->Connect("ValueSet(Long_t)","StatDialogBDT",this, "SetItree()");
00173
00174 fDrawButton->Connect("Clicked()","TGNumberEntry",fInput, "ValueSet(Long_t)");
00175 fDrawButton->Connect("Clicked()", "StatDialogBDT", this, "Redraw()");
00176
00177 fCloseButton->Connect("Clicked()", "StatDialogBDT", this, "Close()");
00178 }
00179
00180 void StatDialogBDT::UpdateCanvases()
00181 {
00182 DrawTree( fItree );
00183 }
00184
00185 void StatDialogBDT::GetNtrees()
00186 {
00187 if(!fWfile.EndsWith(".xml") ){
00188 ifstream fin( fWfile );
00189 if (!fin.good( )) {
00190 cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
00191 return;
00192 }
00193
00194 TString dummy = "";
00195
00196
00197 Int_t nc = 0;
00198 while (!dummy.Contains("NTrees")) {
00199 fin >> dummy;
00200 nc++;
00201 if (nc > 200) {
00202 cout << endl;
00203 cout << "*** Huge problem: could not locate term \"NTrees\" in BDT weight file: "
00204 << fWfile << endl;
00205 cout << "==> panic abort (please contact the TMVA authors)" << endl;
00206 cout << endl;
00207 exit(1);
00208 }
00209 }
00210 fin >> dummy;
00211 fNtrees = dummy.ReplaceAll("\"","").Atoi();
00212 fin.close();
00213 }
00214 else{
00215 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
00216 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
00217 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
00218 while(ch){
00219 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
00220 if(nodeName=="Weights") {
00221 TMVA::gTools().ReadAttr( ch, "NTrees", fNtrees );
00222 break;
00223 }
00224 ch = TMVA::gTools().xmlengine().GetNext(ch);
00225 }
00226 }
00227 cout << "--- Found " << fNtrees << " decision trees in weight file" << endl;
00228
00229 }
00230
00231
00232 void StatDialogBDT::DrawNode( TMVA::DecisionTreeNode *n,
00233 Double_t x, Double_t y,
00234 Double_t xscale, Double_t yscale, TString * vars)
00235 {
00236
00237
00238 Float_t xsize=xscale*1.5;
00239 Float_t ysize=yscale/3;
00240 if (xsize>0.15) xsize=0.1;
00241 if (n->GetLeft() != NULL){
00242 TLine *a1 = new TLine(x-xscale/4,y-ysize,x-xscale,y-ysize*2);
00243 a1->SetLineWidth(2);
00244 a1->Draw();
00245 DrawNode((TMVA::DecisionTreeNode*) n->GetLeft(), x-xscale, y-yscale, xscale/2, yscale, vars);
00246 }
00247 if (n->GetRight() != NULL){
00248 TLine *a1 = new TLine(x+xscale/4,y-ysize,x+xscale,y-ysize*2);
00249 a1->SetLineWidth(2);
00250 a1->Draw();
00251 DrawNode((TMVA::DecisionTreeNode*) n->GetRight(), x+xscale, y-yscale, xscale/2, yscale, vars );
00252 }
00253
00254
00255 TPaveText *t = new TPaveText(x-xsize,y-ysize,x+xsize,y+ysize, "NDC");
00256
00257 t->SetBorderSize(1);
00258
00259 t->SetFillStyle(1);
00260 if (n->GetNodeType() == 1) { t->SetFillColor( kSigColorF ); t->SetTextColor( kSigColorT ); }
00261 else if (n->GetNodeType() == -1) { t->SetFillColor( kBkgColorF ); t->SetTextColor( kBkgColorT ); }
00262 else if (n->GetNodeType() == 0) { t->SetFillColor( kIntColorF ); t->SetTextColor( kIntColorT ); }
00263
00264 char buffer[25];
00265
00266
00267 sprintf( buffer, "R=%4.1f +- %4.1f", n->GetResponse(),n->GetRMS() );
00268 t->AddText(buffer);
00269
00270 if (n->GetNodeType() == 0){
00271 if (n->GetCutType()){
00272 t->AddText(TString(vars[n->GetSelector()]+">"+=::Form("%5.3g",n->GetCutValue())));
00273 }else{
00274 t->AddText(TString(vars[n->GetSelector()]+"<"+=::Form("%5.3g",n->GetCutValue())));
00275 }
00276 }
00277
00278 t->Draw();
00279
00280 return;
00281 }
00282
00283 TMVA::DecisionTree* StatDialogBDT::ReadTree( TString* &vars, Int_t itree )
00284 {
00285 cout << "--- Reading Tree " << itree << " from weight file: " << fWfile << endl;
00286 TMVA::DecisionTree *d = new TMVA::DecisionTree();
00287
00288
00289 if(!fWfile.EndsWith(".xml") ){
00290
00291 ifstream fin( fWfile );
00292 if (!fin.good( )) {
00293 cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
00294 return 0;
00295 }
00296 TString dummy = "";
00297
00298 if (itree >= fNtrees) {
00299 cout << "*** ERROR: requested decision tree: " << itree
00300 << ", but number of trained trees only: " << fNtrees << endl;
00301 return 0;
00302 }
00303
00304
00305 while (!dummy.Contains("#VAR")) fin >> dummy;
00306 fin >> dummy >> dummy >> dummy;
00307
00308
00309 Int_t nVars;
00310 fin >> dummy >> nVars;
00311
00312
00313 vars = new TString[nVars+1];
00314 for (Int_t i = 0; i < nVars; i++) fin >> vars[i] >> dummy >> dummy >> dummy >> dummy;
00315 vars[nVars]="FisherCrit";
00316
00317 char buffer[20];
00318 char line[256];
00319 sprintf(buffer,"Tree %d",itree);
00320
00321 while (!dummy.Contains(buffer)) {
00322 fin.getline(line,256);
00323 dummy = TString(line);
00324 }
00325
00326 d->Read(fin);
00327
00328 fin.close();
00329 }
00330 else{
00331 if (itree >= fNtrees) {
00332 cout << "*** ERROR: requested decision tree: " << itree
00333 << ", but number of trained trees only: " << fNtrees << endl;
00334 return 0;
00335 }
00336 Int_t nVars;
00337 void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
00338 void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
00339 void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
00340 while(ch){
00341 TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
00342 if(nodeName=="Variables"){
00343 TMVA::gTools().ReadAttr( ch, "NVar", nVars);
00344 vars = new TString[nVars+1];
00345 void* varnode = TMVA::gTools().xmlengine().GetChild(ch);
00346 for (Int_t i = 0; i < nVars; i++){
00347 TMVA::gTools().ReadAttr( varnode, "Expression", vars[i]);
00348 varnode = TMVA::gTools().xmlengine().GetNext(varnode);
00349 }
00350 vars[nVars]="FisherCrit";
00351 }
00352 if(nodeName=="Weights") break;
00353 ch = TMVA::gTools().xmlengine().GetNext(ch);
00354 }
00355 ch = TMVA::gTools().xmlengine().GetChild(ch);
00356 for (int i=0; i<itree; i++) ch = TMVA::gTools().xmlengine().GetNext(ch);
00357 d->ReadXML(ch);
00358 }
00359 return d;
00360 }
00361
00362
00363 void StatDialogBDT::DrawTree( Int_t itree )
00364 {
00365 TString *vars;
00366
00367 TMVA::DecisionTree* d = ReadTree( vars, itree );
00368 if (d == 0) return;
00369
00370 UInt_t depth = d->GetTotalTreeDepth();
00371 Double_t ystep = 1.0/(depth + 1.0);
00372
00373 cout << "--- Tree depth: " << depth << endl;
00374
00375 TStyle* TMVAStyle = gROOT->GetStyle("Plain");
00376 Int_t canvasColor = TMVAStyle->GetCanvasColor();
00377
00378 TString cbuffer = Form( "Reading weight file: %s", fWfile.Data() );
00379 TString tbuffer = Form( "Regression Tree no.: %d", itree );
00380 if (!fCanvas) fCanvas = new TCanvas( "c1", cbuffer, 200, 0, 1000, 600 );
00381 else fCanvas->Clear();
00382 fCanvas->Draw();
00383 DrawNode( (TMVA::DecisionTreeNode*)d->GetRoot(), 0.5, 1.-0.5*ystep, 0.25, ystep ,vars);
00384
00385
00386 Double_t yup=0.99;
00387 Double_t ydown=yup-ystep/2.5;
00388 Double_t dy= ystep/2.5 * 0.2;
00389
00390 TPaveText *whichTree = new TPaveText(0.85,ydown,0.98,yup, "NDC");
00391 whichTree->SetBorderSize(1);
00392 whichTree->SetFillStyle(1);
00393 whichTree->SetFillColor( TColor::GetColor( "#ffff33" ) );
00394 whichTree->AddText( tbuffer );
00395 whichTree->Draw();
00396
00397 TPaveText *intermediate = new TPaveText(0.02,ydown,0.15,yup, "NDC");
00398 intermediate->SetBorderSize(1);
00399 intermediate->SetFillStyle(1);
00400 intermediate->SetFillColor( kIntColorF );
00401 intermediate->AddText("Intermediate Nodes");
00402 intermediate->SetTextColor( kIntColorT );
00403 intermediate->Draw();
00404
00405 ydown = ydown - ystep/2.5 -dy;
00406 yup = yup - ystep/2.5 -dy;
00407 TPaveText *signalleaf = new TPaveText(0.02,ydown ,0.15,yup, "NDC");
00408 signalleaf->SetBorderSize(1);
00409 signalleaf->SetFillStyle(1);
00410 signalleaf->SetFillColor( kSigColorF );
00411 signalleaf->AddText("Leaf Nodes");
00412 signalleaf->SetTextColor( kSigColorT );
00413 signalleaf->Draw();
00414
00415
00416
00417
00418
00419
00420
00421
00422
00423
00424
00425
00426 fCanvas->Update();
00427 TString fname = Form("plots/%s_%i", fMethName.Data(), itree );
00428 cout << "--- Creating image: " << fname << endl;
00429 TMVAGlob::imgconv( fCanvas, fname );
00430
00431 TMVAStyle->SetCanvasColor( canvasColor );
00432 }
00433
00434
00435
00436 static std::vector<TControlBar*> BDT_Global__cbar;
00437
00438
00439 void BDT_Reg( const TString& fin = "TMVAReg.root" )
00440 {
00441
00442
00443
00444 TMVAGlob::DestroyCanvases();
00445
00446
00447 TFile* file = TMVAGlob::OpenFile( fin );
00448
00449 TDirectory* dir = file->GetDirectory( "Method_BDT" );
00450 if (!dir) {
00451 cout << "*** Error in macro \"BDT_Reg.C\": cannot find directory \"Method_BDT\" in file: " << fin << endl;
00452 return;
00453 }
00454
00455
00456 TIter next( dir->GetListOfKeys() );
00457 TKey *key(0);
00458 std::vector<TString> methname;
00459 std::vector<TString> path;
00460 std::vector<TString> wfile;
00461 while ((key = (TKey*)next())) {
00462 TDirectory* mdir = dir->GetDirectory( key->GetName() );
00463 if (!mdir) {
00464 cout << "*** Error in macro \"BDT_Reg.C\": cannot find sub-directory: " << key->GetName()
00465 << " in directory: " << dir->GetName() << endl;
00466 return;
00467 }
00468
00469
00470 TObjString* strPath = (TObjString*)mdir->Get( "TrainingPath" );
00471 TObjString* strWFile = (TObjString*)mdir->Get( "WeightFileName" );
00472 if (!strPath || !strWFile) {
00473 cout << "*** Error in macro \"BDT_Reg.C\": could not find TObjStrings \"TrainingPath\" and/or \"WeightFileName\" *** " << endl;
00474 cout << "*** Maybe you are using TMVA >= 3.8.15 with an older training target file ? *** " << endl;
00475 return;
00476 }
00477
00478 methname.push_back( key->GetName() );
00479 path .push_back( strPath->GetString() );
00480 wfile .push_back( strWFile->GetString() );
00481 }
00482
00483
00484 TControlBar* cbar = new TControlBar( "vertical", "Choose weight file:", 50, 50 );
00485 BDT_Global__cbar.push_back(cbar);
00486
00487 for (UInt_t im=0; im<path.size(); im++) {
00488 TString fname = path[im];
00489 if (fname[fname.Length()-1] != '/') fname += "/";
00490 fname += wfile[im];
00491 TString macro = Form( ".x BDT_Reg.C+\(0,\"%s\",\"%s\")", fname.Data(), methname[im].Data() );
00492 cbar->AddButton( fname, macro, "Plot decision trees from this weight file", "button" );
00493 }
00494
00495
00496 #if ROOT_VERSION_CODE < ROOT_VERSION(5,19,0)
00497 cbar->AddButton( "Close", Form("BDT_DeleteTBar(%i)", BDT_Global__cbar.size()-1), "Close this control bar", "button" );
00498 #endif
00499
00500
00501
00502 cbar->SetTextColor("blue");
00503
00504
00505 cbar->Show();
00506 }
00507
00508 void BDT_DeleteTBar(int i)
00509 {
00510
00511 StatDialogBDT::Delete();
00512 TMVAGlob::DestroyCanvases();
00513
00514 delete BDT_Global__cbar[i];
00515 BDT_Global__cbar[i] = 0;
00516 }
00517
00518
00519
00520 void BDT_Reg( Int_t itree, TString wfile = "weights/TMVARegression_BDT.weights.xml", TString methName = "BDT", Bool_t useTMVAStyle = kTRUE )
00521 {
00522
00523 StatDialogBDT::Delete();
00524 TMVAGlob::DestroyCanvases();
00525
00526
00527 if(!wfile.EndsWith(".xml") ){
00528 ifstream fin( wfile );
00529 if (!fin.good( )) {
00530 cout << "*** ERROR: Weight file: " << wfile << " does not exist" << endl;
00531 return;
00532 }
00533 }
00534 std::cout << "test1";
00535
00536 TMVAGlob::Initialize( useTMVAStyle );
00537
00538 StatDialogBDT* gGui = new StatDialogBDT( gClient->GetRoot(), wfile, methName, itree );
00539
00540 gGui->DrawTree( itree );
00541
00542 gGui->RaiseDialog();
00543 }
00544