BDT.C

Go to the documentation of this file.
00001 #include <iostream>
00002 #include <iomanip>
00003 #include <fstream>
00004 
00005 #include "tmvaglob.C"
00006 
00007 #include "RQ_OBJECT.h"
00008 
00009 #include "TROOT.h"
00010 #include "TStyle.h"
00011 #include "TPad.h"
00012 #include "TCanvas.h"
00013 #include "TLine.h"
00014 #include "TFile.h"
00015 #include "TColor.h"
00016 #include "TPaveText.h"
00017 #include "TObjString.h"
00018 #include "TControlBar.h"
00019 
00020 #include "TGWindow.h"
00021 #include "TGButton.h"
00022 #include "TGLabel.h"
00023 #include "TGNumberEntry.h"
00024 
00025 #include "TMVA/DecisionTree.h"
00026 #include "TMVA/Tools.h"
00027 #include "TXMLEngine.h"
00028 
00029 // Uncomment this only if the link problem is solved. The include statement tends
00030 // to use the ROOT classes rather than the local TMVA release
00031 // #include "TMVA/DecisionTree.h"
00032 // #include "TMVA/DecisionTreeNode.h"
00033 
00034 // this macro displays a decision tree read in from the weight file
00035 
00036 static const Int_t kSigColorF = TColor::GetColor( "#2244a5" );  // novel blue 
00037 static const Int_t kBkgColorF = TColor::GetColor( "#dd0033" );  // novel red  
00038 static const Int_t kIntColorF = TColor::GetColor( "#33aa77" );  // novel green
00039 
00040 static const Int_t kSigColorT = 10;
00041 static const Int_t kBkgColorT = 10;
00042 static const Int_t kIntColorT = 10;
00043 
00044 enum PlotType { EffPurity = 0 };
00045 
00046 class StatDialogBDT {  
00047 
00048    RQ_OBJECT("StatDialogBDT")
00049 
00050  public:
00051 
00052    StatDialogBDT( const TGWindow* p, TString wfile = "weights/TMVAClassification_BDT.weights.txt", 
00053                   TString methName = "BDT", Int_t itree = 0 );
00054    virtual ~StatDialogBDT() {
00055       TMVA::DecisionTreeNode::fgIsTraining=false;
00056       fThis = 0;
00057       fMain->CloseWindow();
00058       fMain->Cleanup();
00059       if(gROOT->GetListOfCanvases()->FindObject(fCanvas))
00060         delete fCanvas; 
00061    }
00062    
00063    // draw method
00064    void DrawTree( Int_t itree );
00065 
00066    void RaiseDialog() { if (fMain) { fMain->RaiseWindow(); fMain->Layout(); fMain->MapWindow(); } }
00067    
00068  private:
00069    
00070    TGMainFrame *fMain;
00071    Int_t        fItree;
00072    Int_t        fNtrees;
00073    TCanvas*     fCanvas;
00074 
00075    TGNumberEntry* fInput;
00076 
00077    TGHorizontalFrame* fButtons;
00078    TGTextButton* fDrawButton;
00079    TGTextButton* fCloseButton;
00080 
00081    void UpdateCanvases();
00082 
00083    // draw methods
00084    TMVA::DecisionTree* ReadTree( TString * &vars, Int_t itree );
00085    void                DrawNode( TMVA::DecisionTreeNode *n, 
00086                                  Double_t x, Double_t y, Double_t xscale,  Double_t yscale, TString* vars );
00087    void GetNtrees();
00088 
00089    TString fWfile;
00090    TString fMethName;
00091 
00092  public:
00093 
00094    // static function for external deletion
00095    static void Delete() { if (fThis != 0) { delete fThis; fThis = 0; } }
00096 
00097    // slots
00098    void SetItree(); //*SIGNAL*
00099    void Redraw(); //*SIGNAL*
00100    void Close(); //*SIGNAL*
00101 
00102  private:
00103 
00104    static StatDialogBDT* fThis;
00105 
00106 };
00107 
00108 StatDialogBDT* StatDialogBDT::fThis = 0;
00109 
00110 void StatDialogBDT::SetItree() 
00111 {
00112    fItree = Int_t(fInput->GetNumber());
00113 }
00114 
00115 void StatDialogBDT::Redraw() 
00116 {
00117    UpdateCanvases();
00118 }
00119 
00120 void StatDialogBDT::Close() 
00121 {
00122    delete this;
00123 }
00124 
00125 StatDialogBDT::StatDialogBDT( const TGWindow* p, TString wfile, TString methName, Int_t itree )
00126    : fMain( 0 ),
00127      fItree(itree),
00128      fNtrees(0),
00129      fCanvas(0),
00130      fInput(0),
00131      fButtons(0),
00132      fDrawButton(0),
00133      fCloseButton(0),
00134      fWfile( wfile ),
00135      fMethName( methName )
00136 {
00137    UInt_t totalWidth  = 500;
00138    UInt_t totalHeight = 200;
00139 
00140    fThis = this;
00141 
00142    TMVA::DecisionTreeNode::fgIsTraining=true;
00143 
00144    // read number of decision trees from weight file
00145    GetNtrees();
00146 
00147    // main frame
00148    fMain = new TGMainFrame(p, totalWidth, totalHeight, kMainFrame | kVerticalFrame);
00149 
00150    TGLabel *sigLab = new TGLabel( fMain, Form( "Decision tree [%i-%i]",0,fNtrees-1 ) );
00151    fMain->AddFrame(sigLab, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
00152 
00153    fInput = new TGNumberEntry(fMain, (Double_t) fItree,5,-1,(TGNumberFormat::EStyle) 5);
00154    fMain->AddFrame(fInput, new TGLayoutHints(kLHintsLeft | kLHintsTop,5,5,5,5));
00155    fInput->Resize(100,24);
00156    fInput->SetLimits(TGNumberFormat::kNELLimitMinMax,0,fNtrees-1);
00157 
00158    fButtons = new TGHorizontalFrame(fMain, totalWidth,30);
00159 
00160    fCloseButton = new TGTextButton(fButtons,"&Close");
00161    fButtons->AddFrame(fCloseButton, new TGLayoutHints(kLHintsLeft | kLHintsTop));
00162 
00163    fDrawButton = new TGTextButton(fButtons,"&Draw");
00164    fButtons->AddFrame(fDrawButton, new TGLayoutHints(kLHintsRight | kLHintsTop,15));
00165   
00166    fMain->AddFrame(fButtons,new TGLayoutHints(kLHintsLeft | kLHintsBottom,5,5,5,5));
00167 
00168    fMain->SetWindowName("Decision tree");
00169    fMain->SetWMPosition(0,0);
00170    fMain->MapSubwindows();
00171    fMain->Resize(fMain->GetDefaultSize());
00172    fMain->MapWindow();
00173 
00174    fInput->Connect("ValueSet(Long_t)","StatDialogBDT",this, "SetItree()");
00175 
00176    fDrawButton->Connect("Clicked()","TGNumberEntry",fInput, "ValueSet(Long_t)");
00177    fDrawButton->Connect("Clicked()", "StatDialogBDT", this, "Redraw()");   
00178 
00179    fCloseButton->Connect("Clicked()", "StatDialogBDT", this, "Close()");
00180 }
00181 
00182 void StatDialogBDT::UpdateCanvases() 
00183 {
00184    DrawTree( fItree );
00185 }
00186 
00187 void StatDialogBDT::GetNtrees()
00188 {
00189    if(!fWfile.EndsWith(".xml") ){
00190       ifstream fin( fWfile );
00191       if (!fin.good( )) { // file not found --> Error
00192          cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
00193          return;
00194       }
00195    
00196       TString dummy = "";
00197       
00198       // read total number of trees, and check whether requested tree is in range
00199       Int_t nc = 0;
00200       while (!dummy.Contains("NTrees")) { 
00201          fin >> dummy; 
00202          nc++; 
00203          if (nc > 200) {
00204             cout << endl;
00205             cout << "*** Huge problem: could not locate term \"NTrees\" in BDT weight file: " 
00206                  << fWfile << endl;
00207             cout << "==> panic abort (please contact the TMVA authors)" << endl;
00208             cout << endl;
00209             exit(1);
00210          }
00211       }
00212       fin >> dummy; 
00213       fNtrees = dummy.ReplaceAll("\"","").Atoi();
00214       fin.close();
00215    }
00216    else{
00217       void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
00218       void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
00219       void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
00220       while(ch){
00221          TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
00222          if(nodeName=="Weights") {
00223             TMVA::gTools().ReadAttr( ch, "NTrees", fNtrees );
00224             break;
00225          }
00226          ch = TMVA::gTools().xmlengine().GetNext(ch);
00227       }
00228    }
00229    cout << "--- Found " << fNtrees << " decision trees in weight file" << endl;
00230 
00231 }
00232 
00233 //_______________________________________________________________________
00234 void StatDialogBDT::DrawNode( TMVA::DecisionTreeNode *n, 
00235                                Double_t x, Double_t y, 
00236                                Double_t xscale,  Double_t yscale, TString * vars) 
00237 {
00238    // recursively puts an entries in the histogram for the node and its daughters
00239    //
00240    Float_t xsize=xscale*1.5;
00241    Float_t ysize=yscale/3;
00242    if (xsize>0.15) xsize=0.1; //xscale/2;
00243    if (n->GetLeft() != NULL){
00244       TLine *a1 = new TLine(x-xscale/4,y-ysize,x-xscale,y-ysize*2);
00245       a1->SetLineWidth(2);
00246       a1->Draw();
00247       DrawNode((TMVA::DecisionTreeNode*) n->GetLeft(), x-xscale, y-yscale, xscale/2, yscale, vars);
00248    }
00249    if (n->GetRight() != NULL){
00250       TLine *a1 = new TLine(x+xscale/4,y-ysize,x+xscale,y-ysize*2);
00251       a1->SetLineWidth(2);
00252       a1->Draw();
00253       DrawNode((TMVA::DecisionTreeNode*) n->GetRight(), x+xscale, y-yscale, xscale/2, yscale, vars  );
00254    }
00255 
00256    //   TPaveText *t = new TPaveText(x-xscale/2,y-yscale/2,x+xscale/2,y+yscale/2, "NDC");
00257    TPaveText *t = new TPaveText(x-xsize,y-ysize,x+xsize,y+ysize, "NDC");
00258 
00259    t->SetBorderSize(1);
00260 
00261    t->SetFillStyle(1);
00262    if      (n->GetNodeType() ==  1) { t->SetFillColor( kSigColorF ); t->SetTextColor( kSigColorT ); }
00263    else if (n->GetNodeType() == -1) { t->SetFillColor( kBkgColorF ); t->SetTextColor( kBkgColorT ); }
00264    else if (n->GetNodeType() ==  0) { t->SetFillColor( kIntColorF ); t->SetTextColor( kIntColorT ); }
00265 
00266    char buffer[25];
00267    sprintf( buffer, "N=%f", n->GetNEvents() );
00268    if (n->GetNEvents()>0) t->AddText(buffer);
00269    sprintf( buffer, "S/(S+B)=%4.3f", n->GetPurity() );
00270    t->AddText(buffer);
00271 
00272    if (n->GetNodeType() == 0){
00273       if (n->GetCutType()){
00274          t->AddText(TString(vars[n->GetSelector()]+">"+=::Form("%5.3g",n->GetCutValue())));
00275       }else{
00276          t->AddText(TString(vars[n->GetSelector()]+"<"+=::Form("%5.3g",n->GetCutValue())));
00277       }
00278    }
00279 
00280    t->Draw();
00281 
00282    return;
00283 }
00284 TMVA::DecisionTree* StatDialogBDT::ReadTree( TString* &vars, Int_t itree )
00285 {
00286    cout << "--- Reading Tree " << itree << " from weight file: " << fWfile << endl;
00287    TMVA::DecisionTree *d = new TMVA::DecisionTree();
00288    if(!fWfile.EndsWith(".xml") ){
00289       ifstream fin( fWfile );
00290       if (!fin.good( )) { // file not found --> Error
00291          cout << "*** ERROR: Weight file: " << fWfile << " does not exist" << endl;
00292          return 0;
00293       }
00294       
00295       TString dummy = "";
00296       
00297       if (itree >= fNtrees) {
00298          cout << "*** ERROR: requested decision tree: " << itree 
00299               << ", but number of trained trees only: " << fNtrees << endl;
00300          return 0;
00301       }
00302       
00303       // file header with name
00304       while (!dummy.Contains("#VAR")) fin >> dummy;
00305       fin >> dummy >> dummy >> dummy; // the rest of header line
00306       
00307       // number of variables
00308       Int_t nVars;
00309       fin >> dummy >> nVars;
00310       
00311       // variable mins and maxes
00312       vars = new TString[nVars+1]; // last one is if "fisher cut criterium"
00313       for (Int_t i = 0; i < nVars; i++) fin >> vars[i] >> dummy >> dummy >> dummy >> dummy;
00314       vars[nVars]="FisherCrit";
00315 
00316       char buffer[20];
00317       char line[256];
00318       sprintf(buffer,"Tree %d",itree);
00319 
00320       while (!dummy.Contains(buffer)) {
00321          fin.getline(line,256);
00322          dummy = TString(line);
00323       }
00324 
00325       d->Read(fin);
00326       
00327       fin.close();
00328    }
00329    else{
00330      if (itree >= fNtrees) {
00331          cout << "*** ERROR: requested decision tree: " << itree 
00332               << ", but number of trained trees only: " << fNtrees << endl;
00333          return 0;
00334       }
00335      Int_t nVars;
00336       void* doc = TMVA::gTools().xmlengine().ParseFile(fWfile);
00337       void* rootnode = TMVA::gTools().xmlengine().DocGetRootElement(doc);
00338       void* ch = TMVA::gTools().xmlengine().GetChild(rootnode);
00339       while(ch){
00340          TString nodeName = TString( TMVA::gTools().xmlengine().GetNodeName(ch) );
00341          if(nodeName=="Variables"){
00342             TMVA::gTools().ReadAttr( ch, "NVar", nVars);
00343             vars = new TString[nVars+1]; 
00344             void* varnode =  TMVA::gTools().xmlengine().GetChild(ch);
00345             for (Int_t i = 0; i < nVars; i++){
00346                TMVA::gTools().ReadAttr( varnode, "Expression", vars[i]);
00347                varnode =  TMVA::gTools().xmlengine().GetNext(varnode);
00348             }
00349             vars[nVars]="FisherCrit";
00350          }
00351          if(nodeName=="Weights") break;
00352          ch = TMVA::gTools().xmlengine().GetNext(ch);
00353       }
00354       ch = TMVA::gTools().xmlengine().GetChild(ch);
00355       for (int i=0; i<itree; i++) ch = TMVA::gTools().xmlengine().GetNext(ch);
00356       d->ReadXML(ch);
00357    }
00358    return d;
00359 }
00360 
00361 //_______________________________________________________________________
00362 void StatDialogBDT::DrawTree( Int_t itree )
00363 {
00364    TString *vars;   
00365    TMVA::DecisionTree* d = ReadTree( vars, itree );
00366    if (d == 0) return;
00367 
00368    UInt_t   depth = d->GetTotalTreeDepth();
00369    Double_t ystep = 1.0/(depth + 1.0);
00370 
00371    cout << "--- Tree depth: " << depth << endl;
00372 
00373    TStyle* TMVAStyle   = gROOT->GetStyle("Plain"); // our style is based on Plain
00374    Int_t   canvasColor = TMVAStyle->GetCanvasColor(); // backup
00375 
00376    TString cbuffer = Form( "Reading weight file: %s", fWfile.Data() );
00377    TString tbuffer = Form( "Decision Tree no.: %d", itree );
00378    if (!fCanvas) fCanvas = new TCanvas( "c1", cbuffer, 200, 0, 1000, 600 ); 
00379    else          fCanvas->Clear();
00380    fCanvas->Draw();   
00381 
00382    DrawNode( (TMVA::DecisionTreeNode*)d->GetRoot(), 0.5, 1.-0.5*ystep, 0.25, ystep ,vars);
00383   
00384    // make the legend
00385    Double_t yup=0.99;
00386    Double_t ydown=yup-ystep/2.5;
00387    Double_t dy= ystep/2.5 * 0.2;
00388  
00389    TPaveText *whichTree = new TPaveText(0.85,ydown,0.98,yup, "NDC");
00390    whichTree->SetBorderSize(1);
00391    whichTree->SetFillStyle(1);
00392    whichTree->SetFillColor( TColor::GetColor( "#ffff33" ) );
00393    whichTree->AddText( tbuffer );
00394    whichTree->Draw();
00395 
00396    TPaveText *intermediate = new TPaveText(0.02,ydown,0.15,yup, "NDC");
00397    intermediate->SetBorderSize(1);
00398    intermediate->SetFillStyle(1);
00399    intermediate->SetFillColor( kIntColorF );
00400    intermediate->AddText("Intermediate Nodes");
00401    intermediate->SetTextColor( kIntColorT );
00402    intermediate->Draw();
00403 
00404    ydown = ydown - ystep/2.5 -dy;
00405    yup   = yup - ystep/2.5 -dy;
00406    TPaveText *signalleaf = new TPaveText(0.02,ydown ,0.15,yup, "NDC");
00407    signalleaf->SetBorderSize(1);
00408    signalleaf->SetFillStyle(1);
00409    signalleaf->SetFillColor( kSigColorF );
00410    signalleaf->AddText("Signal Leaf Nodes");
00411    signalleaf->SetTextColor( kSigColorT );
00412    signalleaf->Draw();
00413 
00414    ydown = ydown - ystep/2.5 -dy;
00415    yup   = yup - ystep/2.5 -dy;
00416    TPaveText *backgroundleaf = new TPaveText(0.02,ydown,0.15,yup, "NDC");
00417    backgroundleaf->SetBorderSize(1);
00418    backgroundleaf->SetFillStyle(1);
00419    backgroundleaf->SetFillColor( kBkgColorF );
00420 
00421    backgroundleaf->AddText("Backgr. Leaf Nodes");
00422    backgroundleaf->SetTextColor( kBkgColorT );
00423    backgroundleaf->Draw();
00424 
00425    fCanvas->Update();
00426    TString fname = Form("plots/%s_%i", fMethName.Data(), itree );
00427    cout << "--- Creating image: " << fname << endl;
00428    TMVAGlob::imgconv( fCanvas, fname );   
00429 
00430    TMVAStyle->SetCanvasColor( canvasColor );
00431 }   
00432       
00433 // ========================================================================================
00434 
00435 static std::vector<TControlBar*> BDT_Global__cbar;
00436 
00437 // intermediate GUI
00438 void BDT( const TString& fin = "TMVA.root" )
00439 {
00440    // --- read the available BDT weight files
00441 
00442    // destroy all open cavases
00443    TMVAGlob::DestroyCanvases(); 
00444 
00445    // checks if file with name "fin" is already open, and if not opens one
00446    TFile* file = TMVAGlob::OpenFile( fin );  
00447 
00448    TDirectory* dir = file->GetDirectory( "Method_BDT" );
00449    if (!dir) {
00450       cout << "*** Error in macro \"BDT.C\": cannot find directory \"Method_BDT\" in file: " << fin << endl;
00451       return;
00452    }
00453 
00454    // read all directories
00455    TIter next( dir->GetListOfKeys() );
00456    TKey *key(0);   
00457    std::vector<TString> methname;   
00458    std::vector<TString> path;   
00459    std::vector<TString> wfile;   
00460    while ((key = (TKey*)next())) {
00461       TDirectory* mdir = dir->GetDirectory( key->GetName() );
00462       if (!mdir) {
00463          cout << "*** Error in macro \"BDT.C\": cannot find sub-directory: " << key->GetName() 
00464               << " in directory: " << dir->GetName() << endl;
00465          return;
00466       }
00467 
00468       // retrieve weight file name and path
00469       TObjString* strPath = (TObjString*)mdir->Get( "TrainingPath" );
00470       TObjString* strWFile = (TObjString*)mdir->Get( "WeightFileName" );
00471       if (!strPath || !strWFile) {
00472          cout << "*** Error in macro \"BDT.C\": could not find TObjStrings \"TrainingPath\" and/or \"WeightFileName\" *** " << endl;
00473          cout << "*** Maybe you are using TMVA >= 3.8.15 with an older training target file ? *** " << endl;
00474          return;
00475       }
00476 
00477       methname.push_back( key->GetName() );
00478       path    .push_back( strPath->GetString() );
00479       wfile   .push_back( strWFile->GetString() );
00480    }
00481 
00482    // create the control bar
00483    TControlBar* cbar = new TControlBar( "vertical", "Choose weight file:", 50, 50 );
00484    BDT_Global__cbar.push_back(cbar);
00485 
00486    for (UInt_t im=0; im<path.size(); im++) {
00487       TString fname = path[im];
00488       if (fname[fname.Length()-1] != '/') fname += "/";
00489       fname += wfile[im];
00490       TString macro = Form( ".x BDT.C+\(0,\"%s\",\"%s\")", fname.Data(), methname[im].Data() );
00491       cbar->AddButton( fname, macro, "Plot decision trees from this weight file", "button" );
00492    }
00493 
00494    // *** problems with this button in ROOT 5.19 ***
00495    #if ROOT_VERSION_CODE < ROOT_VERSION(5,19,0)
00496    cbar->AddButton( "Close", Form("BDT_DeleteTBar(%i)", BDT_Global__cbar.size()-1), "Close this control bar", "button" );
00497    #endif
00498    // **********************************************
00499 
00500    // set the style 
00501    cbar->SetTextColor("blue");
00502 
00503    // draw
00504    cbar->Show();   
00505 }
00506 
00507 void BDT_DeleteTBar(int i)
00508 {
00509    // destroy all open canvases
00510    StatDialogBDT::Delete();
00511    TMVAGlob::DestroyCanvases();
00512 
00513    delete BDT_Global__cbar[i];
00514    BDT_Global__cbar[i] = 0;
00515 }
00516 
00517 // input: - No. of tree
00518 //        - the weight file from which the tree is read
00519 void BDT( Int_t itree, TString wfile = "weights/TMVAnalysis_test_BDT.weights.txt", TString methName = "BDT", Bool_t useTMVAStyle = kTRUE ) 
00520 {
00521    // destroy possibly existing dialog windows and/or canvases
00522    StatDialogBDT::Delete();
00523    TMVAGlob::DestroyCanvases(); 
00524 
00525    // quick check if weight file exist
00526    if(!wfile.EndsWith(".xml") ){
00527       ifstream fin( wfile );
00528       if (!fin.good( )) { // file not found --> Error
00529          cout << "*** ERROR: Weight file: " << wfile << " does not exist" << endl;
00530          return;
00531       }
00532    }
00533    std::cout << "test1";
00534    // set style and remove existing canvas'
00535    TMVAGlob::Initialize( useTMVAStyle );
00536 
00537    StatDialogBDT* gGui = new StatDialogBDT( gClient->GetRoot(), wfile, methName, itree );
00538 
00539    gGui->DrawTree( itree );
00540 
00541    gGui->RaiseDialog();
00542 }
00543 

Generated on Tue Jul 5 15:25:40 2011 for ROOT_528-00b_version by  doxygen 1.5.1