DecisionTreeNode.h

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: DecisionTreeNode.h 37986 2011-02-04 21:42:15Z pcanal $
00002 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss, Eckhard von Toerne
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : DecisionTreeNode                                                      *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Node for the Decision Tree                                                *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00015  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00016  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
00017  *      Eckhard von Toerne <evt@physik.uni-bonn.de>  - U. of Bonn, Germany        *
00018  *                                                                                *
00019  * Copyright (c) 2009:                                                            *
00020  *      CERN, Switzerland                                                         *
00021  *      U. of Victoria, Canada                                                    *
00022  *      MPI-K Heidelberg, Germany                                                 *
00023  *      U. of Bonn, Germany                                                       *
00024  *                                                                                *
00025  * Redistribution and use in source and binary forms, with or without             *
00026  * modification, are permitted according to the terms listed in LICENSE           *
00027  * (http://tmva.sourceforge.net/LICENSE)                                          *
00028  **********************************************************************************/
00029 
00030 #ifndef ROOT_TMVA_DecisionTreeNode
00031 #define ROOT_TMVA_DecisionTreeNode
00032 
00033 //////////////////////////////////////////////////////////////////////////
00034 //                                                                      //
00035 // DecisionTreeNode                                                     //
00036 //                                                                      //
00037 // Node for the Decision Tree                                           //
00038 //                                                                      //
00039 //////////////////////////////////////////////////////////////////////////
00040 
#ifndef ROOT_TMVA_Node
#include "TMVA/Node.h"
#endif

#ifndef ROOT_TMVA_Version
#include "TMVA/Version.h"
#endif

#include <iostream>
#include <map>
#include <sstream>
#include <vector>
00052 namespace TMVA {
00053 
00054    class DTNodeTrainingInfo
00055    {
00056    public:
00057       DTNodeTrainingInfo():fSampleMin(), 
00058                            fSampleMax(), 
00059                            fNodeR(0),fSubTreeR(0),fAlpha(0),fG(0),fNTerminal(0),
00060                            fNB(0),fNS(0),fSumTarget(0),fSumTarget2(0),fCC(0), 
00061                            fNSigEvents ( 0 ), fNBkgEvents ( 0 ),
00062                            fNEvents ( -1 ),
00063                            fNSigEvents_unweighted ( 0 ),
00064                            fNBkgEvents_unweighted ( 0 ),
00065                            fNEvents_unweighted ( 0 ),
00066                            fSeparationIndex (-1 ),
00067                            fSeparationGain ( -1 )
00068       {
00069       }
00070       std::vector< Float_t >  fSampleMin; // the minima for each ivar of the sample on the node during training
00071       std::vector< Float_t >  fSampleMax; // the maxima for each ivar of the sample on the node during training
00072       Double_t fNodeR;           // node resubstitution estimate, R(t)
00073       Double_t fSubTreeR;        // R(T) = Sum(R(t) : t in ~T)
00074       Double_t fAlpha;           // critical alpha for this node
00075       Double_t fG;               // minimum alpha in subtree rooted at this node
00076       Int_t    fNTerminal;       // number of terminal nodes in subtree rooted at this node
00077       Double_t fNB;              // sum of weights of background events from the pruning sample in this node
00078       Double_t fNS;              // ditto for the signal events
00079       Float_t  fSumTarget;       // sum of weight*target  used for the calculatio of the variance (regression)
00080       Float_t  fSumTarget2;      // sum of weight*target^2 used for the calculatio of the variance (regression)
00081       Double_t fCC;  // debug variable for cost complexity pruning ..
00082 
00083       Float_t  fNSigEvents;      // sum of weights of signal event in the node
00084       Float_t  fNBkgEvents;      // sum of weights of backgr event in the node
00085       Float_t  fNEvents;         // number of events in that entered the node (during training)
00086       Float_t  fNSigEvents_unweighted;      // sum of signal event in the node
00087       Float_t  fNBkgEvents_unweighted;      // sum of backgr event in the node
00088       Float_t  fNEvents_unweighted;         // number of events in that entered the node (during training)
00089       Float_t  fSeparationIndex; // measure of "purity" (separation between S and B) AT this node
00090       Float_t  fSeparationGain;  // measure of "purity", separation, or information gained BY this nodes selection
00091 
00092       // copy constructor
00093       DTNodeTrainingInfo(const DTNodeTrainingInfo& n) :
00094          fSampleMin(),fSampleMax(), // Samplemin and max are reset in copy constructor
00095          fNodeR(n.fNodeR), fSubTreeR(n.fSubTreeR),
00096          fAlpha(n.fAlpha), fG(n.fG),
00097          fNTerminal(n.fNTerminal),
00098          fNB(n.fNB), fNS(n.fNS),
00099          fSumTarget(0),fSumTarget2(0), // SumTarget reset in copy constructor
00100          fCC(0),
00101          fNSigEvents ( n.fNSigEvents ), fNBkgEvents ( n.fNBkgEvents ),
00102          fNEvents ( n.fNEvents ),
00103          fNSigEvents_unweighted ( n.fNSigEvents_unweighted ),
00104          fNBkgEvents_unweighted ( n.fNBkgEvents_unweighted ),
00105          fNEvents_unweighted ( n.fNEvents_unweighted ),
00106          fSeparationIndex( n.fSeparationIndex ),
00107          fSeparationGain ( n.fSeparationGain )
00108       { }
00109    };
00110 
00111    class Event;
00112    class MsgLogger;
00113 
00114    class DecisionTreeNode: public Node {
00115 
00116    public:
00117 
00118       // constructor of an essentially "empty" node floating in space
00119       DecisionTreeNode ();
00120       // constructor of a daughter node as a daughter of 'p'
00121       DecisionTreeNode (Node* p, char pos); 
00122     
00123       // copy constructor 
00124       DecisionTreeNode (const DecisionTreeNode &n, DecisionTreeNode* parent = NULL); 
00125       
00126       // destructor
00127       virtual ~DecisionTreeNode();
00128 
00129       virtual Node* CreateNode() const { return new DecisionTreeNode(); }
00130 
00131       inline void SetNFisherCoeff(Int_t nvars){fFisherCoeff.resize(nvars);}
00132       inline UInt_t GetNFisherCoeff() const { return fFisherCoeff.size();}
00133       // set fisher coefficients
00134       void SetFisherCoeff(Int_t ivar, Double_t coeff);      
00135       // get fisher coefficients
00136       Double_t GetFisherCoeff(Int_t ivar) const {return fFisherCoeff.at(ivar);}
00137 
00138       // test event if it decends the tree at this node to the right
00139       virtual Bool_t GoesRight( const Event & ) const;
00140 
00141       // test event if it decends the tree at this node to the left
00142       virtual Bool_t GoesLeft ( const Event & ) const;
00143 
00144       // set index of variable used for discrimination at this node
00145       void SetSelector( Short_t i) { fSelector = i; }
00146       // return index of variable used for discrimination at this node
00147       Short_t GetSelector() const { return fSelector; }
00148 
00149       // set the cut value applied at this node
00150       void  SetCutValue ( Float_t c ) { fCutValue  = c; }
00151       // return the cut value applied at this node
00152       Float_t GetCutValue ( void ) const { return fCutValue;  }
00153 
00154       // set true: if event variable > cutValue ==> signal , false otherwise
00155       void SetCutType( Bool_t t   ) { fCutType = t; }
00156       // return kTRUE: Cuts select signal, kFALSE: Cuts select bkg
00157       Bool_t GetCutType( void ) const { return fCutType; }
00158 
00159       // set node type: 1 signal node, -1 bkg leave, 0 intermediate Node
00160       void  SetNodeType( Int_t t ) { fNodeType = t;}
00161       // return node type: 1 signal node, -1 bkg leave, 0 intermediate Node
00162       Int_t GetNodeType( void ) const { return fNodeType; }
00163 
00164       //return  S/(S+B) (purity) at this node (from  training)
00165       Float_t GetPurity( void ) const { return fPurity;}
00166       //calculate S/(S+B) (purity) at this node (from  training)
00167       void SetPurity( void );
00168 
00169       //set the response of the node (for regression)
00170       void SetResponse( Float_t r ) { fResponse = r;}
00171 
00172       //return the response of the node (for regression)
00173       Float_t GetResponse( void ) const { return fResponse;}
00174 
00175       //set the RMS of the response of the node (for regression)
00176       void SetRMS( Float_t r ) { fRMS = r;}
00177 
00178       //return the RMS of the response of the node (for regression)
00179       Float_t GetRMS( void ) const { return fRMS;}
00180 
00181       // set the sum of the signal weights in the node
00182       void SetNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents = s; }
00183 
00184       // set the sum of the backgr weights in the node
00185       void SetNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents = b; }
00186 
00187       // set the number of events that entered the node (during training)
00188       void SetNEvents( Float_t nev ){ fTrainInfo->fNEvents =nev ; }
00189 
00190       // set the sum of the unweighted signal events in the node
00191       void SetNSigEvents_unweighted( Float_t s ) { fTrainInfo->fNSigEvents_unweighted = s; }
00192 
00193       // set the sum of the unweighted backgr events in the node
00194       void SetNBkgEvents_unweighted( Float_t b ) { fTrainInfo->fNBkgEvents_unweighted = b; }
00195 
00196       // set the number of unweighted events that entered the node (during training)
00197       void SetNEvents_unweighted( Float_t nev ){ fTrainInfo->fNEvents_unweighted =nev ; }
00198 
00199       // increment the sum of the signal weights in the node
00200       void IncrementNSigEvents( Float_t s ) { fTrainInfo->fNSigEvents += s; }
00201 
00202       // increment the sum of the backgr weights in the node
00203       void IncrementNBkgEvents( Float_t b ) { fTrainInfo->fNBkgEvents += b; }
00204 
00205       // increment the number of events that entered the node (during training)
00206       void IncrementNEvents( Float_t nev ){ fTrainInfo->fNEvents +=nev ; }
00207 
00208       // increment the sum of the signal weights in the node
00209       void IncrementNSigEvents_unweighted( ) { fTrainInfo->fNSigEvents_unweighted += 1; }
00210 
00211       // increment the sum of the backgr weights in the node
00212       void IncrementNBkgEvents_unweighted( ) { fTrainInfo->fNBkgEvents_unweighted += 1; }
00213 
00214       // increment the number of events that entered the node (during training)
00215       void IncrementNEvents_unweighted( ){ fTrainInfo->fNEvents_unweighted +=1 ; }
00216 
00217       // return the sum of the signal weights in the node
00218       Float_t GetNSigEvents( void ) const  { return fTrainInfo->fNSigEvents; }
00219 
00220       // return the sum of the backgr weights in the node
00221       Float_t GetNBkgEvents( void ) const  { return fTrainInfo->fNBkgEvents; }
00222 
00223       // return  the number of events that entered the node (during training)
00224       Float_t GetNEvents( void ) const  { return fTrainInfo->fNEvents; }
00225 
00226       // return the sum of unweighted signal weights in the node
00227       Float_t GetNSigEvents_unweighted( void ) const  { return fTrainInfo->fNSigEvents_unweighted; }
00228 
00229       // return the sum of unweighted backgr weights in the node
00230       Float_t GetNBkgEvents_unweighted( void ) const  { return fTrainInfo->fNBkgEvents_unweighted; }
00231 
00232       // return  the number of unweighted events that entered the node (during training)
00233       Float_t GetNEvents_unweighted( void ) const  { return fTrainInfo->fNEvents_unweighted; }
00234 
00235 
00236       // set the choosen index, measure of "purity" (separation between S and B) AT this node
00237       void SetSeparationIndex( Float_t sep ){ fTrainInfo->fSeparationIndex =sep ; }
00238       // return the separation index AT this node
00239       Float_t GetSeparationIndex( void ) const  { return fTrainInfo->fSeparationIndex; }
00240 
00241       // set the separation, or information gained BY this nodes selection
00242       void SetSeparationGain( Float_t sep ){ fTrainInfo->fSeparationGain =sep ; }
00243       // return the gain in separation obtained by this nodes selection
00244       Float_t GetSeparationGain( void ) const  { return fTrainInfo->fSeparationGain; }
00245 
00246       // printout of the node
00247       virtual void Print( ostream& os ) const;
00248 
00249       // recursively print the node and its daughters (--> print the 'tree')
00250       virtual void PrintRec( ostream&  os ) const;
00251 
00252       virtual void AddAttributesToNode(void* node) const;
00253       virtual void AddContentToNode(std::stringstream& s) const;
00254 
00255       // recursively clear the nodes content (S/N etc, but not the cut criteria)
00256       void ClearNodeAndAllDaughters();
00257 
00258       // get pointers to children, mother in the tree
00259 
00260       // return pointer to the left/right daughter or parent node
00261       inline virtual DecisionTreeNode* GetLeft( )   const { return dynamic_cast<DecisionTreeNode*>(fLeft); }
00262       inline virtual DecisionTreeNode* GetRight( )  const { return dynamic_cast<DecisionTreeNode*>(fRight); }
00263       inline virtual DecisionTreeNode* GetParent( ) const { return dynamic_cast<DecisionTreeNode*>(fParent); }
00264 
00265       // set pointer to the left/right daughter and parent node
00266       inline virtual void SetLeft  (Node* l) { fLeft   = dynamic_cast<DecisionTreeNode*>(l);} 
00267       inline virtual void SetRight (Node* r) { fRight  = dynamic_cast<DecisionTreeNode*>(r);} 
00268       inline virtual void SetParent(Node* p) { fParent = dynamic_cast<DecisionTreeNode*>(p);} 
00269 
00270 
00271 
00272 
00273       // the node resubstitution estimate, R(t), for Cost Complexity pruning
00274       inline void SetNodeR( Double_t r ) { fTrainInfo->fNodeR = r;    }
00275       inline Double_t GetNodeR( ) const  { return fTrainInfo->fNodeR; }
00276 
00277       // the resubstitution estimate, R(T_t), of the tree rooted at this node
00278       inline void SetSubTreeR( Double_t r ) { fTrainInfo->fSubTreeR = r;    }
00279       inline Double_t GetSubTreeR( ) const  { return fTrainInfo->fSubTreeR; }
00280 
00281       //                             R(t) - R(T_t)
00282       // the critical point alpha =  -------------
00283       //                              |~T_t| - 1
00284       inline void SetAlpha( Double_t alpha ) { fTrainInfo->fAlpha = alpha; }
00285       inline Double_t GetAlpha( ) const      { return fTrainInfo->fAlpha;  }
00286 
00287       // the minimum alpha in the tree rooted at this node
00288       inline void SetAlphaMinSubtree( Double_t g ) { fTrainInfo->fG = g;    }
00289       inline Double_t GetAlphaMinSubtree( ) const  { return fTrainInfo->fG; }
00290 
00291       // number of terminal nodes in the subtree rooted here
00292       inline void SetNTerminal( Int_t n ) { fTrainInfo->fNTerminal = n;    }
00293       inline Int_t GetNTerminal( ) const  { return fTrainInfo->fNTerminal; }
00294 
00295       // number of background/signal events from the pruning validation sample
00296       inline void SetNBValidation( Double_t b ) { fTrainInfo->fNB = b; }
00297       inline void SetNSValidation( Double_t s ) { fTrainInfo->fNS = s; }
00298       inline Double_t GetNBValidation( ) const  { return fTrainInfo->fNB; }
00299       inline Double_t GetNSValidation( ) const  { return fTrainInfo->fNS; }
00300 
00301 
00302       inline void SetSumTarget(Float_t t)  {fTrainInfo->fSumTarget = t; }
00303       inline void SetSumTarget2(Float_t t2){fTrainInfo->fSumTarget2 = t2; }
00304 
00305       inline void AddToSumTarget(Float_t t)  {fTrainInfo->fSumTarget += t; }
00306       inline void AddToSumTarget2(Float_t t2){fTrainInfo->fSumTarget2 += t2; }
00307 
00308       inline Float_t GetSumTarget()  const {return fTrainInfo? fTrainInfo->fSumTarget : -9999;}
00309       inline Float_t GetSumTarget2() const {return fTrainInfo? fTrainInfo->fSumTarget2: -9999;}
00310 
00311 
00312       // reset the pruning validation data
00313       void ResetValidationData( );
00314 
00315       // flag indicates whether this node is terminal
00316       inline Bool_t IsTerminal() const            { return fIsTerminalNode; }
00317       inline void SetTerminal( Bool_t s = kTRUE ) { fIsTerminalNode = s;    }
00318       void PrintPrune( ostream& os ) const ;
00319       void PrintRecPrune( ostream& os ) const;
00320 
00321       void     SetCC(Double_t cc);
00322       Double_t GetCC() const {return (fTrainInfo? fTrainInfo->fCC : -1.);}
00323 
00324       Float_t GetSampleMin(UInt_t ivar) const;
00325       Float_t GetSampleMax(UInt_t ivar) const;
00326       void     SetSampleMin(UInt_t ivar, Float_t xmin);
00327       void     SetSampleMax(UInt_t ivar, Float_t xmax);
00328 
00329       static bool fgIsTraining; // static variable to flag training phase in which we need fTrainInfo
00330 
00331    protected:
00332 
00333       static MsgLogger* fgLogger;    // static because there is a huge number of nodes...
00334 
00335       std::vector<Double_t>       fFisherCoeff;    // the other fisher coeff (offset at the last element
00336 
00337       Float_t  fCutValue;        // cut value appplied on this node to discriminate bkg against sig
00338       Bool_t   fCutType;         // true: if event variable > cutValue ==> signal , false otherwise
00339       Short_t  fSelector;        // index of variable used in node selection (decision tree)
00340 
00341       Float_t  fResponse;        // response value in case of regression
00342       Float_t  fRMS;             // response RMS of the regression node
00343       Int_t    fNodeType;        // Type of node: -1 == Bkg-leaf, 1 == Signal-leaf, 0 = internal
00344       Float_t  fPurity;          // the node purity
00345 
00346       Bool_t   fIsTerminalNode;    //! flag to set node as terminal (i.e., without deleting its descendants)
00347 
00348       mutable DTNodeTrainingInfo* fTrainInfo;
00349 
00350    private:
00351 
00352       virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
00353       virtual Bool_t ReadDataRecord( istream& is, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
00354       virtual void ReadContent(std::stringstream& s);
00355 
00356       ClassDef(DecisionTreeNode,0) // Node for the Decision Tree 
00357    };
00358 } // namespace TMVA
00359 
00360 #endif

Generated on Tue Jul 5 14:27:27 2011 for ROOT_528-00b_version by  doxygen 1.5.1