CCTreeWrapper.h

Go to the documentation of this file.
00001 
00002 /**********************************************************************************
00003  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00004  * Package: TMVA                                                                  *
00005  * Class  : CCTreeWrapper                                                         *
00006  * Web    : http://tmva.sourceforge.net                                           *
00007  *                                                                                *
00008  * Description: a light wrapper of a decision tree, used to perform cost          *
00009  *              complexity pruning (CCP) "in-place"                               *
00010  *                                                                                *  
00011  * Author: Doug Schouten (dschoute@sfu.ca)                                        *
00012  *                                                                                *
00013  *                                                                                *
00014  * Copyright (c) 2007:                                                            *
00015  *      CERN, Switzerland                                                         *
00016  *      MPI-K Heidelberg, Germany                                                 *
00017  *      U. of Texas at Austin, USA                                                *
00018  *                                                                                *
00019  * Redistribution and use in source and binary forms, with or without             *
00020  * modification, are permitted according to the terms listed in LICENSE           *
00021  * (http://tmva.sourceforge.net/LICENSE)                                          *
00022  **********************************************************************************/
00023 
00024 #ifndef ROOT_TMVA_CCTreeWrapper
00025 #define ROOT_TMVA_CCTreeWrapper
00026 
00027 #ifndef ROOT_TMVA_Event
00028 #include "TMVA/Event.h"
00029 #endif
00030 #ifndef ROOT_TMVA_SeparationBase
00031 #include "TMVA/SeparationBase.h"
00032 #endif
00033 #ifndef ROOT_TMVA_DecisionTree
00034 #include "TMVA/DecisionTree.h"
00035 #endif
00036 #ifndef ROOT_TMVA_DataSet
00037 #include "TMVA/DataSet.h"
00038 #endif
00039 #ifndef ROOT_TMVA_Version
00040 #include "TMVA/Version.h"
00041 #endif
00042 
00043 
00044 namespace TMVA {
00045 
00046    class CCTreeWrapper {
00047 
00048    public:
00049 
00050       typedef std::vector<Event*> EventList;
00051 
00052       /////////////////////////////////////////////////////////////
00053       // CCTreeNode - a light wrapper of a decision tree node    //
00054       //                                                         //
00055       /////////////////////////////////////////////////////////////
00056 
00057       class CCTreeNode : virtual public Node {
00058 
00059       public:
00060 
00061          CCTreeNode( DecisionTreeNode* n = NULL );
00062          virtual ~CCTreeNode( );
00063       
00064          virtual Node* CreateNode() const { return new CCTreeNode(); }
00065 
00066          // set |~T_t|, the number of terminal descendants of node t 
00067          inline void SetNLeafDaughters( Int_t N ) { fNLeafDaughters = (N > 0 ? N : 0); }
00068 
00069          // return |~T_t|
00070          inline Int_t GetNLeafDaughters() const { return fNLeafDaughters; }
00071 
00072          // set R(t), the node resubstitution estimate (Gini, misclassification, etc.) for the node t
00073          inline void SetNodeResubstitutionEstimate( Double_t R ) { fNodeResubstitutionEstimate = (R >= 0 ? R : 0.0); }
00074       
00075          // return R(t) for node t
00076          inline Double_t GetNodeResubstitutionEstimate( ) const { return fNodeResubstitutionEstimate; }
00077 
00078          // set R(T_t) = sum[t' in ~T_t]{ R(t) }, the resubstitution estimate for the branch rooted at
00079          // node t (it is an estimate because it is calculated from the training dataset, i.e., the original tree)
00080          inline void SetResubstitutionEstimate( Double_t R ) { fResubstitutionEstimate = (R >= 0 ?  R : 0.0); }
00081       
00082          // return R(T_t) for node t
00083          inline Double_t GetResubstitutionEstimate( ) const { return fResubstitutionEstimate; }
00084       
00085          // set the critical point of alpha
00086          //             R(t) - R(T_t)
00087          //  alpha_c <  ------------- := g(t)
00088          //              |~T_t| - 1
00089          // which is the value of alpha such that the branch rooted at node t is pruned
00090          inline void SetAlphaC( Double_t alpha ) { fAlphaC = alpha; }
00091 
00092          // get the critical alpha value for this node
00093          inline Double_t GetAlphaC( ) const { return fAlphaC; }
00094 
00095          // set the minimum critical alpha value for descendants of node t ( G(t) = min(alpha_c, g(t_l), g(t_r)) )
00096          inline void SetMinAlphaC( Double_t alpha ) { fMinAlphaC = alpha; }
00097 
00098          // get the minimum critical alpha value 
00099          inline Double_t GetMinAlphaC( ) const { return fMinAlphaC; }
00100 
00101          // get the pointer to the wrapped DT node
00102          inline DecisionTreeNode* GetDTNode( ) const { return fDTNode; }
00103 
00104          // get pointers to children, mother in the CC tree
00105          inline CCTreeNode* GetLeftDaughter( ) { return dynamic_cast<CCTreeNode*>(GetLeft()); }
00106          inline CCTreeNode* GetRightDaughter( ) { return dynamic_cast<CCTreeNode*>(GetRight()); }
00107          inline CCTreeNode* GetMother( ) { return dynamic_cast<CCTreeNode*>(GetParent()); }
00108 
00109          // printout of the node (can be read in with ReadDataRecord)
00110          virtual void Print( ostream& os ) const;
00111 
00112          // recursive printout of the node and its daughters 
00113          virtual void PrintRec ( ostream& os ) const;
00114 
00115          virtual void AddAttributesToNode(void* node) const;
00116          virtual void AddContentToNode(std::stringstream& s) const;
00117          
00118 
00119          // test event if it decends the tree at this node to the right  
00120          inline virtual Bool_t GoesRight( const Event& e ) const { return (GetDTNode() != NULL ? 
00121                                                                            GetDTNode()->GoesRight(e) : false); }
00122       
00123          // test event if it decends the tree at this node to the left 
00124          inline virtual Bool_t GoesLeft ( const Event& e ) const { return (GetDTNode() != NULL ? 
00125                                                                            GetDTNode()->GoesLeft(e) : false); }
00126       
00127       private:
00128 
00129          // initialize a node from a data record
00130          virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
00131          virtual Bool_t ReadDataRecord( std::istream& in, UInt_t tmva_Version_Code = TMVA_VERSION_CODE );
00132          virtual void ReadContent(std::stringstream& s);
00133          
00134          Int_t fNLeafDaughters; //! number of terminal descendants
00135          Double_t fNodeResubstitutionEstimate; //! R(t) = misclassification rate for node t
00136          Double_t fResubstitutionEstimate; //! R(T_t) = sum[t' in ~T_t]{ R(t) }
00137          Double_t fAlphaC; //! critical point, g(t) = alpha_c(t)
00138          Double_t fMinAlphaC; //! G(t), minimum critical point of t and its descendants
00139          DecisionTreeNode* fDTNode; //! pointer to wrapped node in the decision tree
00140       };
00141 
00142       CCTreeWrapper( DecisionTree* T,  SeparationBase* qualityIndex );
00143       ~CCTreeWrapper( );
00144 
00145       // return the decision tree output for an event 
00146       Double_t CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf = false );
00147       // return the misclassification rate of a pruned tree for a validation event sample
00148       Double_t TestTreeQuality( const EventList* validationSample );
00149       Double_t TestTreeQuality( const DataSet* validationSample );
00150 
00151       // remove the branch rooted at node t
00152       void PruneNode( CCTreeNode* t );
00153       // initialize the node t and all its descendants
00154       void InitTree( CCTreeNode* t );
00155 
00156       // return the root node for this tree
00157       CCTreeNode* GetRoot() { return fRoot; }
00158    private:
00159       SeparationBase* fQualityIndex;  //! pointer to the used quality index calculator
00160       DecisionTree* fDTParent;        //! pointer to underlying DecisionTree
00161       CCTreeNode* fRoot;              //! the root node of the (wrapped) decision Tree
00162    };
00163 
00164 }
00165 
00166 #endif
00167 
00168 
00169 

Generated on Tue Jul 5 14:27:19 2011 for ROOT_528-00b_version by  doxygen 1.5.1