BDT_Reg.C

Go to the documentation of this file.
00001 #include <iostream>
00002 #include <iomanip>
00003 #include <fstream>
00004 
00005 #include "tmvaglob.C"
00006 
00007 #include "RQ_OBJECT.h"
00008 
00009 #include "TROOT.h"
00010 #include "TStyle.h"
00011 #include "TPad.h"
00012 #include "TCanvas.h"
00013 #include "TLine.h"
00014 #include "TFile.h"
00015 #include "TColor.h"
00016 #include "TPaveText.h"
00017 #include "TObjString.h"
00018 #include "TControlBar.h"
00019 
00020 #include "TGWindow.h"
00021 #include "TGButton.h"
00022 #include "TGLabel.h"
00023 #include "TGNumberEntry.h"
00024 
00025 #include "TMVA/DecisionTree.h"
00026 #include "TMVA/Tools.h"
00027 #include "TXMLEngine.h"
00028 
00029 // Uncomment this only if the link problem is solved. The include statement tends
00030 // to use the ROOT classes rather than the local TMVA release
00031 // #include "TMVA/DecisionTree.h"
00032 // #include "TMVA/DecisionTreeNode.h"
00033 
00034 // this macro displays a decision tree read in from the weight file
00035 
00036 static const Int_t kSigColorF = TColor::GetColor( "#2244a5" );  // novel blue 
00037 static const Int_t kBkgColorF = TColor::GetColor( "#dd0033" );  // novel red  
00038 static const Int_t kIntColorF = TColor::GetColor( "#33aa77" );  // novel green
00039 
00040 static const Int_t kSigColorT = 10;
00041 static const Int_t kBkgColorT = 10;
00042 static const Int_t kIntColorT = 10;
00043 
00044 enum PlotType { EffPurity = 0 };
00045 
00046 class StatDialogBDT {  
00047 
00048    RQ_OBJECT("StatDialogBDT")
00049 
00050  public:
00051 
00052    StatDialogBDT( const TGWindow* p, TString wfile = "weights/TMVARegression_BDT.weights.xml", 
00053                   TString methName = "BDT", Int_t itree = 0 );
00054    virtual ~StatDialogBDT() {
00055       TMVA::DecisionTreeNode::fgIsTraining=false;
00056       fThis = 0;
00057       fMain->CloseWindow();
00058       fMain->Cleanup();
00059       if(gROOT->GetListOfCanvases()->FindObject(fCanvas))
00060         delete fCanvas; 
00061    }
00062    
00063    // draw method
00064    void DrawTree( Int_t itree );
00065 
00066    void RaiseDialog() { if (fMain) { fMain->RaiseWindow(); fMain->Layout(); fMain->MapWindow(); } }
00067    
00068  private:
00069    
00070    TGMainFrame *fMain;
00071    Int_t        fItree;
00072    Int_t        fNtrees;
00073    TCanvas*     fCanvas;
00074 
00075    TGNumberEntry* fInput;
00076 
00077    TGHorizontalFrame* fButtons;
00078    TGTextButton* fDrawButton;
00079    TGTextButton* fCloseButton;
00080 
00081    void UpdateCanvases();
00082 
00083    // draw methods
00084    TMVA::DecisionTree* ReadTree( TString * &vars, Int_t itree );
00085    void                DrawNode( TMVA::DecisionTreeNode *n, 
00086                                  Double_t x, Double_t y, Double_t xscale,  Double_t yscale, TString* vars );
00087    void GetNtrees();
00088 
00089    TString fWfile;
00090    TString fMethName;
00091 
00092  public:
00093 
00094    // static function for external deletion
00095    static void Delete() { if (fThis != 0) { delete fThis; fThis = 0; } }
00096 
00097    // slots
00098    void SetItree(); //*SIGNAL*
00099    void Redraw(); //*SIGNAL*
00100    void Close(); //*SIGNAL*
00101 
00102  private:
00103 
00104    static StatDialogBDT* fThis;
00105 
00106 };
00107 
00108 StatDialogBDT* StatDialogBDT::fThis = 0;
00109 
00110 void StatDialogBDT::SetItree() 
00111 {
00112    fItree = Int_t(fInput->GetNumber());
00113 }
00114 
00115 void StatDialogBDT::Redraw() 
00116 {
00117    UpdateCanvases();
00118 }
00119 
00120 void StatDialogBDT::Close() 
00121 {
00122    delete this;
00123 }
00124 
00125 StatDialogBDT::StatDialogBDT( const TGWindow* p, TString wfile, TString methName, Int_t itree )
00126    : fMain( 0 ),
00127      fItree(itree),
00128      fNtrees(0),
00129      fCanvas(0),
00130      fInput(0),
00131      fButtons(0),
00132      fDrawButton(0),
00133      fCloseButton(0),
00134      fWfile( wfile ),
00135      fMethName( methName )
00136 {
00137    UInt_t totalWidth  = 500;
00138    UInt_t totalHeight = 200;
00139 
00140    fThis = this;
00141 
00142    // read number of decision trees from weight file
00143    GetNtrees();
00144 
00145    // main frame
00146    fMain = new TGMainFrame(p, totalWidth, totalHeight, kMainFrame | kVerticalFrame);
00147 
00148    TGLabel *sigLab = new TGLabel( fMain, Form( "Regression tree [%i-%i]",0,fNtrees-1 ) );
00149    fMain->AddFrame(sigLab, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
00150 
00151    fInput = new TGNumberEntry(fMain, (Double_t) fItree,5,-1,(TGNumberFormat::EStyle) 5);
00152    fMain->AddFrame(fInput, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
00153    fInput->Resize(100,24);
00154    fInput->SetLimits(TGNumberFormat::kNELLimitMinMax,0,fNtrees-1);
00155 
00156    fButtons = new TGHorizontalFrame(fMain, totalWidth,30);
00157 
00158    fCloseButton = new TGTextButton(fButtons,"&Close");
00159    fButtons->AddFrame(fCloseButton, new TGLayoutHints(kLHintsLeft | kLHintsTop));
00160 
00161    fDrawButton = new TGTextButton(fButtons,"&Draw");
00162    fButtons->AddFrame(fDrawButton, new TGLayoutHints(kLHintsRight | kLHintsTop,15));
00163   
00164    fMain->AddFrame(fButtons,new TGLayoutHints(kLHintsLeft | kLHintsBottom,5,5,5,5));
00165 
00166    fMain->SetWindowName("Regression tree");
00167    fMain->SetWMPosition(0,0);
00168    fMain->MapSubwindows();
00169    fMain->Resize(fMain->GetDefaultSize());
00170    fMain->MapWindow();
00171 
00172    fInput->Connect("ValueSet(Long_t)","StatDialogBDT",this, "SetItree()");
00173 
00174    fDrawButton->Connect("Clicked()","TGNumberEntry",fInput, "ValueSet(Long_t)");
00175    fDrawButton->Connect("Clicked()", "StatDialogBDT", this, "Redraw()");   
00176 
00177    fCloseButton->Connect("Clicked()", "StatDialogBDT", this, "Close()");
00178 }
00179 
00180 void StatDialogBDT::UpdateCanvases() 
00181 {
00182    DrawTree( fItree );
00183 }
00184 
00185 void StatDialogBDT::GetNtrees()
00186 {
00187    if(!fWfile.EndsWith(".xml") ){
00188       ifstream fin( fWfile );
00189       if (!fin.good( )) { // file not found --> Error
00190          cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
00191          return;
00192       }
00193    
00194       TString dummy = "";
00195       
00196       // read total number of trees, and check whether requested tree is in range
00197       Int_t nc = 0;
00198       while (!dummy.Contains("NTrees")) { 
00199          fin >> dummy; 
00200          nc++; 
00201          if (nc > 200) {
00202             cout << endl;
00203             cout << "*** Huge problem: could not locate term \"NTrees\" in BDT weight file: " 
00204                  << fWfile << endl;
00205             cout << "==> panic abort (please contact the TMVA authors)" << endl;
00206             cout << endl;
00207             exit(1);
00208          }
00209       }
00210       fin >> dummy; 
00211       fNtrees = dummy.ReplaceAll("\"","").Atoi();
00212       fin.close();
00213    }
00214    else{
00215       void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
00216       void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
00217       void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
00218       while(ch){
00219          TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
00220          if(nodeName=="Weights") {
00221             TMVA::gTools().ReadAttr( ch, "NTrees", fNtrees );
00222             break;
00223          }
00224          ch = TMVA::gTools().xmlengine().GetNext(ch);
00225       }
00226    }
00227    cout << "--- Found " << fNtrees << " decision trees in weight file" << endl;
00228 
00229 }
00230 
00231 //_______________________________________________________________________
00232 void StatDialogBDT::DrawNode( TMVA::DecisionTreeNode *n, 
00233                                Double_t x, Double_t y, 
00234                                Double_t xscale,  Double_t yscale, TString * vars) 
00235 {
00236    // recursively puts an entries in the histogram for the node and its daughters
00237    //
00238    Float_t xsize=xscale*1.5;
00239    Float_t ysize=yscale/3;
00240    if (xsize>0.15) xsize=0.1;
00241    if (n->GetLeft() != NULL){
00242       TLine *a1 = new TLine(x-xscale/4,y-ysize,x-xscale,y-ysize*2);
00243       a1->SetLineWidth(2);
00244       a1->Draw();
00245       DrawNode((TMVA::DecisionTreeNode*) n->GetLeft(), x-xscale, y-yscale, xscale/2, yscale, vars);
00246    }
00247    if (n->GetRight() != NULL){
00248       TLine *a1 = new TLine(x+xscale/4,y-ysize,x+xscale,y-ysize*2);
00249       a1->SetLineWidth(2);
00250       a1->Draw();
00251       DrawNode((TMVA::DecisionTreeNode*) n->GetRight(), x+xscale, y-yscale, xscale/2, yscale, vars  );
00252    }
00253 
00254    //   TPaveText *t = new TPaveText(x-xscale/2,y-yscale/2,x+xscale/2,y+yscale/2, "NDC");
00255    TPaveText *t = new TPaveText(x-xsize,y-ysize,x+xsize,y+ysize, "NDC");
00256 
00257    t->SetBorderSize(1);
00258 
00259    t->SetFillStyle(1);
00260    if      (n->GetNodeType() ==  1) { t->SetFillColor( kSigColorF ); t->SetTextColor( kSigColorT ); }
00261    else if (n->GetNodeType() == -1) { t->SetFillColor( kBkgColorF ); t->SetTextColor( kBkgColorT ); }
00262    else if (n->GetNodeType() ==  0) { t->SetFillColor( kIntColorF ); t->SetTextColor( kIntColorT ); }
00263 
00264    char buffer[25];
00265    //   sprintf( buffer, "N=%f", n->GetNEvents() );
00266    //   t->AddText(buffer);
00267    sprintf( buffer, "R=%4.1f +- %4.1f", n->GetResponse(),n->GetRMS() );
00268    t->AddText(buffer);
00269 
00270    if (n->GetNodeType() == 0){
00271       if (n->GetCutType()){
00272          t->AddText(TString(vars[n->GetSelector()]+">"+=::Form("%5.3g",n->GetCutValue())));
00273       }else{
00274          t->AddText(TString(vars[n->GetSelector()]+"<"+=::Form("%5.3g",n->GetCutValue())));
00275       }
00276    }
00277 
00278    t->Draw();
00279 
00280    return;
00281 }
00282 
00283 TMVA::DecisionTree* StatDialogBDT::ReadTree( TString* &vars, Int_t itree )
00284 {
00285    cout << "--- Reading Tree " << itree << " from weight file: " << fWfile << endl;
00286    TMVA::DecisionTree *d = new TMVA::DecisionTree();
00287 
00288 
00289    if(!fWfile.EndsWith(".xml") ){
00290 
00291       ifstream fin( fWfile );
00292       if (!fin.good( )) { // file not found --> Error
00293          cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
00294          return 0;
00295       }
00296       TString dummy = "";
00297       
00298       if (itree >= fNtrees) {
00299          cout << "*** ERROR: requested decision tree: " << itree 
00300               << ", but number of trained trees only: " << fNtrees << endl;
00301          return 0;
00302       }
00303       
00304       // file header with name
00305       while (!dummy.Contains("#VAR")) fin >> dummy;
00306       fin >> dummy >> dummy >> dummy; // the rest of header line
00307 
00308       // number of variables
00309       Int_t nVars;
00310       fin >> dummy >> nVars;
00311       
00312       // variable mins and maxes
00313       vars = new TString[nVars+1];
00314       for (Int_t i = 0; i < nVars; i++) fin >> vars[i] >> dummy >> dummy >> dummy >> dummy;
00315       vars[nVars]="FisherCrit";
00316       
00317       char buffer[20];
00318       char line[256];
00319       sprintf(buffer,"Tree %d",itree);
00320 
00321       while (!dummy.Contains(buffer)) {
00322          fin.getline(line,256);
00323          dummy = TString(line);
00324       }
00325 
00326       d->Read(fin);
00327       
00328       fin.close();
00329    }
00330    else{
00331       if (itree >= fNtrees) {
00332          cout << "*** ERROR: requested decision tree: " << itree 
00333                << ", but number of trained trees only: " << fNtrees << endl;
00334          return 0;
00335       }
00336       Int_t nVars;
00337       void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
00338       void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
00339       void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
00340       while(ch){
00341          TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
00342          if(nodeName=="Variables"){
00343             TMVA::gTools().ReadAttr( ch, "NVar", nVars);
00344             vars = new TString[nVars+1]; 
00345             void* varnode =  TMVA::gTools().xmlengine().GetChild(ch);
00346             for (Int_t i = 0; i < nVars; i++){
00347                TMVA::gTools().ReadAttr( varnode, "Expression", vars[i]);
00348                varnode =  TMVA::gTools().xmlengine().GetNext(varnode);
00349             }
00350             vars[nVars]="FisherCrit";
00351          }
00352          if(nodeName=="Weights") break;
00353          ch = TMVA::gTools().xmlengine().GetNext(ch);
00354       }
00355       ch = TMVA::gTools().xmlengine().GetChild(ch);
00356       for (int i=0; i<itree; i++) ch = TMVA::gTools().xmlengine().GetNext(ch);
00357       d->ReadXML(ch);
00358    }
00359    return d;
00360 }
00361 
00362 //_______________________________________________________________________
00363 void StatDialogBDT::DrawTree( Int_t itree )
00364 {
00365    TString *vars;   
00366 
00367    TMVA::DecisionTree* d = ReadTree( vars, itree );
00368    if (d == 0) return;
00369 
00370    UInt_t   depth = d->GetTotalTreeDepth();
00371    Double_t ystep = 1.0/(depth + 1.0);
00372 
00373    cout << "--- Tree depth: " << depth << endl;
00374 
00375    TStyle* TMVAStyle   = gROOT->GetStyle("Plain"); // our style is based on Plain
00376    Int_t   canvasColor = TMVAStyle->GetCanvasColor(); // backup
00377 
00378    TString cbuffer = Form( "Reading weight file: %s", fWfile.Data() );
00379    TString tbuffer = Form( "Regression Tree no.: %d", itree );
00380    if (!fCanvas) fCanvas = new TCanvas( "c1", cbuffer, 200, 0, 1000, 600 ); 
00381    else          fCanvas->Clear();
00382    fCanvas->Draw();   
00383    DrawNode( (TMVA::DecisionTreeNode*)d->GetRoot(), 0.5, 1.-0.5*ystep, 0.25, ystep ,vars);
00384   
00385    // make the legend
00386    Double_t yup=0.99;
00387    Double_t ydown=yup-ystep/2.5;
00388    Double_t dy= ystep/2.5 * 0.2;
00389  
00390    TPaveText *whichTree = new TPaveText(0.85,ydown,0.98,yup, "NDC");
00391    whichTree->SetBorderSize(1);
00392    whichTree->SetFillStyle(1);
00393    whichTree->SetFillColor( TColor::GetColor( "#ffff33" ) );
00394    whichTree->AddText( tbuffer );
00395    whichTree->Draw();
00396 
00397    TPaveText *intermediate = new TPaveText(0.02,ydown,0.15,yup, "NDC");
00398    intermediate->SetBorderSize(1);
00399    intermediate->SetFillStyle(1);
00400    intermediate->SetFillColor( kIntColorF );
00401    intermediate->AddText("Intermediate Nodes");
00402    intermediate->SetTextColor( kIntColorT );
00403    intermediate->Draw();
00404 
00405    ydown = ydown - ystep/2.5 -dy;
00406    yup   = yup - ystep/2.5 -dy;
00407    TPaveText *signalleaf = new TPaveText(0.02,ydown ,0.15,yup, "NDC");
00408    signalleaf->SetBorderSize(1);
00409    signalleaf->SetFillStyle(1);
00410    signalleaf->SetFillColor( kSigColorF );
00411    signalleaf->AddText("Leaf Nodes");
00412    signalleaf->SetTextColor( kSigColorT );
00413    signalleaf->Draw();
00414 /*
00415    ydown = ydown - ystep/2.5 -dy;
00416    yup   = yup - ystep/2.5 -dy;
00417    TPaveText *backgroundleaf = new TPaveText(0.02,ydown,0.15,yup, "NDC");
00418    backgroundleaf->SetBorderSize(1);
00419    backgroundleaf->SetFillStyle(1);
00420    backgroundleaf->SetFillColor( kBkgColorF );
00421 
00422    backgroundleaf->AddText("Backgr. Leaf Nodes");
00423    backgroundleaf->SetTextColor( kBkgColorT );
00424    backgroundleaf->Draw();
00425 */
00426    fCanvas->Update();
00427    TString fname = Form("plots/%s_%i", fMethName.Data(), itree );
00428    cout << "--- Creating image: " << fname << endl;
00429    TMVAGlob::imgconv( fCanvas, fname );   
00430 
00431    TMVAStyle->SetCanvasColor( canvasColor );
00432 }   
00433       
00434 // ========================================================================================
00435 
00436 static std::vector<TControlBar*> BDT_Global__cbar;
00437 
00438 // intermediate GUI
00439 void BDT_Reg( const TString& fin = "TMVAReg.root" )
00440 {
00441    // --- read the available BDT weight files
00442 
00443    // destroy all open cavases
00444    TMVAGlob::DestroyCanvases(); 
00445 
00446    // checks if file with name "fin" is already open, and if not opens one
00447    TFile* file = TMVAGlob::OpenFile( fin );  
00448 
00449    TDirectory* dir = file->GetDirectory( "Method_BDT" );
00450    if (!dir) {
00451       cout << "*** Error in macro \"BDT_Reg.C\": cannot find directory \"Method_BDT\" in file: " << fin << endl;
00452       return;
00453    }
00454 
00455    // read all directories
00456    TIter next( dir->GetListOfKeys() );
00457    TKey *key(0);   
00458    std::vector<TString> methname;   
00459    std::vector<TString> path;   
00460    std::vector<TString> wfile;   
00461    while ((key = (TKey*)next())) {
00462       TDirectory* mdir = dir->GetDirectory( key->GetName() );
00463       if (!mdir) {
00464          cout << "*** Error in macro \"BDT_Reg.C\": cannot find sub-directory: " << key->GetName() 
00465               << " in directory: " << dir->GetName() << endl;
00466          return;
00467       }
00468 
00469       // retrieve weight file name and path
00470       TObjString* strPath = (TObjString*)mdir->Get( "TrainingPath" );
00471       TObjString* strWFile = (TObjString*)mdir->Get( "WeightFileName" );
00472       if (!strPath || !strWFile) {
00473          cout << "*** Error in macro \"BDT_Reg.C\": could not find TObjStrings \"TrainingPath\" and/or \"WeightFileName\" *** " << endl;
00474          cout << "*** Maybe you are using TMVA >= 3.8.15 with an older training target file ? *** " << endl;
00475          return;
00476       }
00477 
00478       methname.push_back( key->GetName() );
00479       path    .push_back( strPath->GetString() );
00480       wfile   .push_back( strWFile->GetString() );
00481    }
00482 
00483    // create the control bar
00484    TControlBar* cbar = new TControlBar( "vertical", "Choose weight file:", 50, 50 );
00485    BDT_Global__cbar.push_back(cbar);
00486 
00487    for (UInt_t im=0; im<path.size(); im++) {
00488       TString fname = path[im];
00489       if (fname[fname.Length()-1] != '/') fname += "/";
00490       fname += wfile[im];
00491       TString macro = Form( ".x BDT_Reg.C+\(0,\"%s\",\"%s\")", fname.Data(), methname[im].Data() );
00492       cbar->AddButton( fname, macro, "Plot decision trees from this weight file", "button" );
00493    }
00494 
00495    // *** problems with this button in ROOT 5.19 ***
00496    #if ROOT_VERSION_CODE < ROOT_VERSION(5,19,0)
00497    cbar->AddButton( "Close", Form("BDT_DeleteTBar(%i)", BDT_Global__cbar.size()-1), "Close this control bar", "button" );
00498    #endif
00499    // **********************************************
00500 
00501    // set the style 
00502    cbar->SetTextColor("blue");
00503 
00504    // draw
00505    cbar->Show();   
00506 }
00507 
00508 void BDT_DeleteTBar(int i)
00509 {
00510    // destroy all open canvases
00511    StatDialogBDT::Delete();
00512    TMVAGlob::DestroyCanvases();
00513 
00514    delete BDT_Global__cbar[i];
00515    BDT_Global__cbar[i] = 0;
00516 }
00517 
00518 // input: - No. of tree
00519 //        - the weight file from which the tree is read
00520 void BDT_Reg( Int_t itree, TString wfile = "weights/TMVARegression_BDT.weights.xml", TString methName = "BDT", Bool_t useTMVAStyle = kTRUE ) 
00521 {
00522    // destroy possibly existing dialog windows and/or canvases
00523    StatDialogBDT::Delete();
00524    TMVAGlob::DestroyCanvases(); 
00525 
00526    // quick check if weight file exist
00527    if(!wfile.EndsWith(".xml") ){
00528       ifstream fin( wfile );
00529       if (!fin.good( )) { // file not found --> Error
00530          cout << "*** ERROR: Weight file: " << wfile << " does not exist" << endl;
00531          return;
00532       }
00533    }
00534    std::cout << "test1";
00535    // set style and remove existing canvas'
00536    TMVAGlob::Initialize( useTMVAStyle );
00537 
00538    StatDialogBDT* gGui = new StatDialogBDT( gClient->GetRoot(), wfile, methName, itree );
00539 
00540    gGui->DrawTree( itree );
00541 
00542    gGui->RaiseDialog();
00543 }
00544 

Generated on Tue Jul 5 15:25:40 2011 for ROOT_528-00b_version by  doxygen 1.5.1