DecisionTreeNode.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: DecisionTreeNode.cxx 37986 2011-02-04 21:42:15Z pcanal $
00002 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : TMVA::DecisionTreeNode                                                *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Implementation of a Decision Tree Node                                    *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00015  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00016  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
00017  *      Eckhard von Toerne <evt@physik.uni-bonn.de>  - U. of Bonn, Germany        *
00018  *                                                                                *
00019  * CopyRight (c) 2009:                                                            *
00020  *      CERN, Switzerland                                                         *
00021  *      U. of Victoria, Canada                                                    *
00022  *      MPI-K Heidelberg, Germany                                                 *
00023  *      U. of Bonn, Germany                                                       *
00024  *                                                                                *
00025  * Redistribution and use in source and binary forms, with or without             *
00026  * modification, are permitted according to the terms listed in LICENSE           *
00027  * (http://tmva.sourceforge.net/LICENSE)                                          *
00028  **********************************************************************************/
00029 
00030 //_______________________________________________________________________
00031 //
00032 // Node for the Decision Tree
00033 //
00034 // The node specifies ONE variable out of the given set of selection variable
00035 // that is used to split the sample which "arrives" at the node, into a left
00036 // (background-enhanced) and a right (signal-enhanced) sample.
00037 //_______________________________________________________________________
00038 
00039 #include <algorithm>
00040 #include <exception>
00041 #include <iomanip>
00042 
00043 #include "TMVA/MsgLogger.h"
00044 #include "TMVA/DecisionTreeNode.h"
00045 #include "TMVA/Tools.h"
00046 #include "TMVA/Event.h"
00047 
// pull std::string into scope unqualified (used in ReadDataRecord)
using std::string;

// ROOT dictionary/streamer boilerplate for this class
ClassImp(TMVA::DecisionTreeNode)

// shared message logger, lazily created by the first node constructed
TMVA::MsgLogger* TMVA::DecisionTreeNode::fgLogger = 0;
// global switch: when true, new nodes allocate a DTNodeTrainingInfo block
bool     TMVA::DecisionTreeNode::fgIsTraining = false;
00054 
00055 //_______________________________________________________________________
00056 TMVA::DecisionTreeNode::DecisionTreeNode()
00057    : TMVA::Node(),
00058      fCutValue(0),
00059      fCutType ( kTRUE ),
00060      fSelector ( -1 ),
00061      fResponse(-99 ),
00062      fRMS(0),
00063      fNodeType (-99 ),
00064      fPurity (-99),
00065      fIsTerminalNode( kFALSE )
00066 {
00067    // constructor of an essentially "empty" node floating in space
00068    if (!fgLogger) fgLogger = new TMVA::MsgLogger( "DecisionTreeNode" );
00069 
00070    if (fgIsTraining){
00071       fTrainInfo = new DTNodeTrainingInfo();
00072       //std::cout << "Node constructor with TrainingINFO"<<std::endl;
00073    }
00074    else {
00075       //std::cout << "**Node constructor WITHOUT TrainingINFO"<<std::endl;
00076       fTrainInfo = 0;
00077    }
00078 }
00079 
00080 //_______________________________________________________________________
00081 TMVA::DecisionTreeNode::DecisionTreeNode(TMVA::Node* p, char pos)
00082    : TMVA::Node(p, pos),
00083      fCutValue( 0 ),
00084      fCutType ( kTRUE ),
00085      fSelector( -1 ),
00086      fResponse(-99 ),
00087      fRMS(0),
00088      fNodeType( -99 ),
00089      fPurity (-99),
00090      fIsTerminalNode( kFALSE )
00091 {
00092    // constructor of a daughter node as a daughter of 'p'
00093    if (!fgLogger) fgLogger = new TMVA::MsgLogger( "DecisionTreeNode" );
00094 
00095    if (fgIsTraining){
00096       fTrainInfo = new DTNodeTrainingInfo();
00097       //std::cout << "Node constructor with TrainingINFO"<<std::endl;
00098    }
00099    else {
00100       //std::cout << "**Node constructor WITHOUT TrainingINFO"<<std::endl;
00101       fTrainInfo = 0;
00102    }
00103 }
00104 
00105 //_______________________________________________________________________
00106 TMVA::DecisionTreeNode::DecisionTreeNode(const TMVA::DecisionTreeNode &n,
00107                                          DecisionTreeNode* parent)
00108    : TMVA::Node(n),
00109      fCutValue( n.fCutValue ),
00110      fCutType ( n.fCutType ),
00111      fSelector( n.fSelector ),
00112      fResponse( n.fResponse ),
00113      fRMS     ( n.fRMS),
00114      fNodeType( n.fNodeType ),
00115      fPurity  ( n.fPurity),
00116      fIsTerminalNode( n.fIsTerminalNode )
00117 {
00118    // copy constructor of a node. It will result in an explicit copy of
00119    // the node and recursively all it's daughters
00120    if (!fgLogger) fgLogger = new TMVA::MsgLogger( "DecisionTreeNode" );
00121 
00122    this->SetParent( parent );
00123    if (n.GetLeft() == 0 ) this->SetLeft(NULL);
00124    else this->SetLeft( new DecisionTreeNode( *((DecisionTreeNode*)(n.GetLeft())),this));
00125 
00126    if (n.GetRight() == 0 ) this->SetRight(NULL);
00127    else this->SetRight( new DecisionTreeNode( *((DecisionTreeNode*)(n.GetRight())),this));
00128 
00129    if (fgIsTraining){
00130       fTrainInfo = new DTNodeTrainingInfo(*(n.fTrainInfo));
00131       //std::cout << "Node constructor with TrainingINFO"<<std::endl;
00132    }
00133    else {
00134       //std::cout << "**Node constructor WITHOUT TrainingINFO"<<std::endl;
00135       fTrainInfo = 0;
00136    }
00137 }
00138 
//_______________________________________________________________________
TMVA::DecisionTreeNode::~DecisionTreeNode(){
   // destructor: release the (optional) training-time bookkeeping.
   // NOTE(review): daughter nodes are presumably deleted by the base class
   // or the owning tree -- confirm before relying on this dtor for cleanup.
   delete fTrainInfo;
}
00144 
00145 
00146 //_______________________________________________________________________
00147 Bool_t TMVA::DecisionTreeNode::GoesRight(const TMVA::Event & e) const
00148 {
00149    // test event if it decends the tree at this node to the right
00150    Bool_t result;
00151    // first check if the fisher criterium is used or ordinary cuts:
00152    if (GetNFisherCoeff() == 0){
00153       
00154       result = (e.GetValue(this->GetSelector()) > this->GetCutValue() );
00155 
00156    }else{
00157       
00158       Double_t fisher = this->GetFisherCoeff(fFisherCoeff.size()-1); // the offset
00159       for (UInt_t ivar=0; ivar<fFisherCoeff.size()-1; ivar++)
00160          fisher += this->GetFisherCoeff(ivar)*(e.GetValue(ivar));
00161 
00162       result = fisher > this->GetCutValue();
00163    }
00164 
00165    if (fCutType == kTRUE) return result; //the cuts are selecting Signal ;
00166    else return !result;
00167 }
00168 
00169 //_______________________________________________________________________
00170 Bool_t TMVA::DecisionTreeNode::GoesLeft(const TMVA::Event & e) const
00171 {
00172    // test event if it decends the tree at this node to the left
00173    if (!this->GoesRight(e)) return kTRUE;
00174    else return kFALSE;
00175 }
00176 
00177 
00178 //_______________________________________________________________________
00179 void TMVA::DecisionTreeNode::SetPurity( void )
00180 {
00181    // return the S/(S+B) (purity) for the node
00182    // REM: even if nodes with purity 0.01 are very PURE background nodes, they still
00183    //      get a small value of the purity.
00184 
00185    if ( ( this->GetNSigEvents() + this->GetNBkgEvents() ) > 0 ) {
00186       fPurity = this->GetNSigEvents() / ( this->GetNSigEvents() + this->GetNBkgEvents());
00187    }
00188    else {
00189       *fgLogger << kINFO << "Zero events in purity calcuation , return purity=0.5" << Endl;
00190       this->Print(*fgLogger);
00191       fPurity = 0.5;
00192    }
00193    return;
00194 }
00195 
00196 // print a node
00197 //_______________________________________________________________________
00198 void TMVA::DecisionTreeNode::Print(ostream& os) const
00199 {
00200    //print the node
00201    os << "< ***  "  << std::endl;
00202    os << " d: "     << this->GetDepth()
00203       << std::setprecision(6)
00204       << "NCoef: "  << this->GetNFisherCoeff();
00205    for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) { os << "fC"<<i<<": " << this->GetFisherCoeff(i);}
00206    os << " ivar: "  << this->GetSelector()
00207       << " cut: "   << this->GetCutValue() 
00208       << " cType: " << this->GetCutType()
00209       << " s: "     << this->GetNSigEvents()
00210       << " b: "     << this->GetNBkgEvents()
00211       << " nEv: "   << this->GetNEvents()
00212       << " suw: "   << this->GetNSigEvents_unweighted()
00213       << " buw: "   << this->GetNBkgEvents_unweighted()
00214       << " nEvuw: " << this->GetNEvents_unweighted()
00215       << " sepI: "  << this->GetSeparationIndex()
00216       << " sepG: "  << this->GetSeparationGain()
00217       << " nType: " << this->GetNodeType()
00218       << std::endl;
00219 
00220    os << "My address is " << long(this) << ", ";
00221    if (this->GetParent() != NULL) os << " parent at addr: "         << long(this->GetParent()) ;
00222    if (this->GetLeft()   != NULL) os << " left daughter at addr: "  << long(this->GetLeft());
00223    if (this->GetRight()  != NULL) os << " right daughter at addr: " << long(this->GetRight()) ;
00224 
00225    os << " **** > " << std::endl;
00226 }
00227 
00228 //_______________________________________________________________________
00229 void TMVA::DecisionTreeNode::PrintRec(ostream& os) const
00230 {
00231    //recursively print the node and its daughters (--> print the 'tree')
00232 
00233    os << this->GetDepth()
00234       << std::setprecision(6)
00235       << " "         << this->GetPos()
00236       << "NCoef: "   << this->GetNFisherCoeff();
00237    for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) {os << "fC"<<i<<": " << this->GetFisherCoeff(i);}
00238    os << " ivar: "   << this->GetSelector()
00239       << " cut: "    << this->GetCutValue()
00240       << " cType: "  << this->GetCutType()
00241       << " s: "      << this->GetNSigEvents()
00242       << " b: "      << this->GetNBkgEvents()
00243       << " nEv: "    << this->GetNEvents()
00244       << " suw: "    << this->GetNSigEvents_unweighted()
00245       << " buw: "    << this->GetNBkgEvents_unweighted()
00246       << " nEvuw: "  << this->GetNEvents_unweighted()
00247       << " sepI: "   << this->GetSeparationIndex()
00248       << " sepG: "   << this->GetSeparationGain()
00249       << " res: "    << this->GetResponse()
00250       << " rms: "    << this->GetRMS()
00251       << " nType: "  << this->GetNodeType();
00252    if (this->GetCC() > 10000000000000.) os << " CC: " << 100000. << std::endl;
00253    else os << " CC: "  << this->GetCC() << std::endl;
00254 
00255    if (this->GetLeft()  != NULL) this->GetLeft() ->PrintRec(os);
00256    if (this->GetRight() != NULL) this->GetRight()->PrintRec(os);
00257 }
00258 
//_______________________________________________________________________
Bool_t TMVA::DecisionTreeNode::ReadDataRecord( istream& is, UInt_t tmva_Version_Code )
{
   // Read one node data block from a text weight-file stream.
   // Returns kFALSE when the end-of-tree sentinel (depth == -1) is read,
   // kTRUE after this node has been filled.
   // The token layout differs before/after TMVA version 4.0.0 (the newer
   // format adds the regression response and the CC value).

   string tmp;   // swallows the attribute labels between the values

   Float_t cutVal, cutType, nsig, nbkg, nEv, nsig_unweighted, nbkg_unweighted, nEv_unweighted;
   Float_t separationIndex, separationGain, response(-99), cc(0);
   Int_t   depth, ivar, nodeType;
   ULong_t lseq;   // node sequence number: read from the stream but not stored
   char pos;       // position tag of this node relative to its parent

   is >> depth;                                         // 2
   if ( depth==-1 ) { return kFALSE; }                  // sentinel: no more nodes
   //   if ( depth==-1 ) { delete this; return kFALSE; }
   is >> pos ;                                          // r
   this->SetDepth(depth);
   this->SetPos(pos);

   if (tmva_Version_Code < TMVA_VERSION(4,0,0)) {
      // old (pre-4.0.0) format: no response, no CC
      is >> tmp >> lseq
         >> tmp >> ivar
         >> tmp >> cutVal
         >> tmp >> cutType
         >> tmp >> nsig
         >> tmp >> nbkg
         >> tmp >> nEv
         >> tmp >> nsig_unweighted
         >> tmp >> nbkg_unweighted
         >> tmp >> nEv_unweighted
         >> tmp >> separationIndex
         >> tmp >> separationGain
         >> tmp >> nodeType;
   } else {
      // current format: additionally carries response and CC
      is >> tmp >> lseq
         >> tmp >> ivar
         >> tmp >> cutVal
         >> tmp >> cutType
         >> tmp >> nsig
         >> tmp >> nbkg
         >> tmp >> nEv
         >> tmp >> nsig_unweighted
         >> tmp >> nbkg_unweighted
         >> tmp >> nEv_unweighted
         >> tmp >> separationIndex
         >> tmp >> separationGain
         >> tmp >> response
         >> tmp >> nodeType
         >> tmp >> cc;
   }

   this->SetSelector((UInt_t)ivar);
   this->SetCutValue(cutVal);
   this->SetCutType(cutType);
   this->SetNodeType(nodeType);
   // the event statistics can only be stored when training info exists
   if (fTrainInfo){
      this->SetNSigEvents(nsig);
      this->SetNBkgEvents(nbkg);
      this->SetNEvents(nEv);
      this->SetNSigEvents_unweighted(nsig_unweighted);
      this->SetNBkgEvents_unweighted(nbkg_unweighted);
      this->SetNEvents_unweighted(nEv_unweighted);
      this->SetSeparationIndex(separationIndex);
      this->SetSeparationGain(separationGain);
      this->SetPurity();
      //      this->SetResponse(response); old .txt weightfiles don't know regression yet
      this->SetCC(cc);
   }

   return kTRUE;
}
00331 
00332 //_______________________________________________________________________
00333 void TMVA::DecisionTreeNode::ClearNodeAndAllDaughters()
00334 {
00335    // clear the nodes (their S/N, Nevents etc), just keep the structure of the tree
00336    SetNSigEvents(0);
00337    SetNBkgEvents(0);
00338    SetNEvents(0);
00339    SetNSigEvents_unweighted(0);
00340    SetNBkgEvents_unweighted(0);
00341    SetNEvents_unweighted(0);
00342    SetSeparationIndex(-1);
00343    SetSeparationGain(-1);
00344    SetPurity();
00345 
00346    if (this->GetLeft()  != NULL) ((DecisionTreeNode*)(this->GetLeft()))->ClearNodeAndAllDaughters();
00347    if (this->GetRight() != NULL) ((DecisionTreeNode*)(this->GetRight()))->ClearNodeAndAllDaughters();
00348 }
00349 
00350 //_______________________________________________________________________
00351 void TMVA::DecisionTreeNode::ResetValidationData( ) {
00352    // temporary stored node values (number of events, etc.) that originate
00353    // not from the training but from the validation data (used in pruning)
00354    SetNBValidation( 0.0 );
00355    SetNSValidation( 0.0 );
00356    SetSumTarget( 0 );
00357    SetSumTarget2( 0 );
00358 
00359    if(GetLeft() != NULL && GetRight() != NULL) {
00360       GetLeft()->ResetValidationData();
00361       GetRight()->ResetValidationData();
00362    }
00363 }
00364 
00365 //_______________________________________________________________________
00366 void TMVA::DecisionTreeNode::PrintPrune( ostream& os ) const {
00367    // printout of the node (can be read in with ReadDataRecord)
00368 
00369    os << "----------------------" << std::endl
00370       << "|~T_t| " << GetNTerminal() << std::endl
00371       << "R(t): " << GetNodeR() << std::endl
00372       << "R(T_t): " << GetSubTreeR() << std::endl
00373       << "g(t): " << GetAlpha() << std::endl
00374       << "G(t): "  << GetAlphaMinSubtree() << std::endl;
00375 }
00376 
00377 //_______________________________________________________________________
00378 void TMVA::DecisionTreeNode::PrintRecPrune( ostream& os ) const {
00379    // recursive printout of the node and its daughters
00380 
00381    this->PrintPrune(os);
00382    if(this->GetLeft() != NULL && this->GetRight() != NULL) {
00383       ((DecisionTreeNode*)this->GetLeft())->PrintRecPrune(os);
00384       ((DecisionTreeNode*)this->GetRight())->PrintRecPrune(os);
00385    }
00386 }
00387 
00388 //_______________________________________________________________________
00389 void TMVA::DecisionTreeNode::SetCC(Double_t cc)
00390 {
00391    if (fTrainInfo) fTrainInfo->fCC = cc;
00392    else *fgLogger << kFATAL << "call to SetCC without trainingInfo" << Endl;
00393 }
00394 
00395 //_______________________________________________________________________
00396 Float_t TMVA::DecisionTreeNode::GetSampleMin(UInt_t ivar) const {
00397    // return the minimum of variable ivar from the training sample
00398    // that pass/end up in this node
00399    if (fTrainInfo && ivar < fTrainInfo->fSampleMin.size()) return fTrainInfo->fSampleMin[ivar];
00400    else *fgLogger << kFATAL << "You asked for Min of the event sample in node for variable "
00401                  << ivar << " that is out of range" << Endl;
00402    return -9999;
00403 }
00404 
00405 //_______________________________________________________________________
00406 Float_t TMVA::DecisionTreeNode::GetSampleMax(UInt_t ivar) const {
00407    // return the maximum of variable ivar from the training sample
00408    // that pass/end up in this node
00409    if (fTrainInfo && ivar < fTrainInfo->fSampleMin.size()) return fTrainInfo->fSampleMax[ivar];
00410    else *fgLogger << kFATAL << "You asked for Max of the event sample in node for variable "
00411                  << ivar << " that is out of range" << Endl;
00412    return 9999;
00413 }
00414 
00415 //_______________________________________________________________________
00416 void TMVA::DecisionTreeNode::SetSampleMin(UInt_t ivar, Float_t xmin){
00417    // set the minimum of variable ivar from the training sample
00418    // that pass/end up in this node
00419    if ( fTrainInfo) {
00420       if ( ivar >= fTrainInfo->fSampleMin.size()) fTrainInfo->fSampleMin.resize(ivar+1);
00421       fTrainInfo->fSampleMin[ivar]=xmin;
00422    }
00423 }
00424 
00425 //_______________________________________________________________________
00426 void TMVA::DecisionTreeNode::SetSampleMax(UInt_t ivar, Float_t xmax){
00427    // set the maximum of variable ivar from the training sample
00428    // that pass/end up in this node
00429    if( ! fTrainInfo ) return;
00430    if ( ivar >= fTrainInfo->fSampleMax.size() ) 
00431       fTrainInfo->fSampleMax.resize(ivar+1);
00432    fTrainInfo->fSampleMax[ivar]=xmax;
00433 }
00434 
00435 //_______________________________________________________________________
00436 void TMVA::DecisionTreeNode::ReadAttributes(void* node, UInt_t /* tmva_Version_Code */  )
00437 {
00438    Float_t tempNSigEvents,tempNBkgEvents;
00439 
00440    Int_t nCoef;
00441    if (gTools().HasAttr(node, "NCoef")){
00442       gTools().ReadAttr(node, "NCoef",  nCoef                  );
00443       this->SetNFisherCoeff(nCoef);
00444       Double_t tmp;
00445       for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) {
00446          gTools().ReadAttr(node, Form("fC%d",i),  tmp          );
00447          this->SetFisherCoeff(i,tmp);
00448       }
00449    }else{
00450       this->SetNFisherCoeff(0);
00451    }
00452    gTools().ReadAttr(node, "IVar",  fSelector               );
00453    gTools().ReadAttr(node, "Cut",   fCutValue               );
00454    gTools().ReadAttr(node, "cType", fCutType                );               
00455    if (gTools().HasAttr(node,"res")) gTools().ReadAttr(node, "res",   fResponse);
00456    if (gTools().HasAttr(node,"rms")) gTools().ReadAttr(node, "rms",   fRMS);
00457    //   else { 
00458    if( gTools().HasAttr(node, "purity") ) {
00459       gTools().ReadAttr(node, "purity",fPurity );
00460    } else {
00461       gTools().ReadAttr(node, "nS",    tempNSigEvents             );
00462       gTools().ReadAttr(node, "nB",    tempNBkgEvents             );
00463       fPurity = tempNSigEvents / (tempNSigEvents + tempNBkgEvents);
00464    }
00465    //   }
00466    gTools().ReadAttr(node, "nType", fNodeType               );
00467 }
00468 
00469 
00470 //_______________________________________________________________________
00471 void TMVA::DecisionTreeNode::AddAttributesToNode(void* node) const
00472 {
00473    // add attribute to xml
00474    gTools().AddAttr(node, "NCoef", GetNFisherCoeff());
00475    for (Int_t i=0; i< (Int_t) this->GetNFisherCoeff(); i++) 
00476       gTools().AddAttr(node, Form("fC%d",i),  this->GetFisherCoeff(i));
00477 
00478    gTools().AddAttr(node, "IVar",  GetSelector());
00479    gTools().AddAttr(node, "Cut",   GetCutValue());
00480    gTools().AddAttr(node, "cType", GetCutType());
00481 
00482    //UInt_t analysisType = (dynamic_cast<const TMVA::DecisionTree*>(GetParentTree()) )->GetAnalysisType();
00483    //   if ( analysisType == TMVA::Types:: kRegression) {
00484    gTools().AddAttr(node, "res",   GetResponse());
00485    gTools().AddAttr(node, "rms",   GetRMS());
00486    //} else if ( analysisType == TMVA::Types::kClassification) {
00487    gTools().AddAttr(node, "purity",GetPurity());
00488    //}
00489    gTools().AddAttr(node, "nType", GetNodeType());
00490 }
00491 
00492 //_______________________________________________________________________
00493 void  TMVA::DecisionTreeNode::SetFisherCoeff(Int_t ivar, Double_t coeff)
00494 {
00495    // set fisher coefficients
00496    if ((Int_t) fFisherCoeff.size()<ivar+1) fFisherCoeff.resize(ivar+1) ; 
00497    fFisherCoeff[ivar]=coeff;      
00498 }
00499 
//_______________________________________________________________________
void TMVA::DecisionTreeNode::AddContentToNode( std::stringstream& /*s*/ ) const
{
   // intentionally empty: the interface (inherited from the BinarySearchTree
   // node design) requires this hook, but decision-tree nodes carry all
   // their content in XML attributes (see AddAttributesToNode)
}
00507 
//_______________________________________________________________________
void TMVA::DecisionTreeNode::ReadContent( std::stringstream& /*s*/ )
{
   // intentionally empty: the interface (inherited from the BinarySearchTree
   // node design) requires this hook, but decision-tree nodes restore all
   // their content from XML attributes (see ReadAttributes)
}

Generated on Tue Jul 5 15:16:47 2011 for ROOT_528-00b_version by  doxygen 1.5.1