00001 00002 /********************************************************************************** 00003 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * 00004 * Package: TMVA * 00005 * Class : CCTreeWrapper * 00006 * Web : http://tmva.sourceforge.net * 00007 * * 00008 * Description: a light wrapper of a decision tree, used to perform cost * 00009 * complexity pruning "in-place" Cost Complexity Pruning * 00010 * * 00011 * Author: Doug Schouten (dschoute@sfu.ca) * 00012 * * 00013 * * 00014 * Copyright (c) 2007: * 00015 * CERN, Switzerland * 00016 * MPI-K Heidelberg, Germany * 00017 * U. of Texas at Austin, USA * 00018 * * 00019 * Redistribution and use in source and binary forms, with or without * 00020 * modification, are permitted according to the terms listed in LICENSE * 00021 * (http://tmva.sourceforge.net/LICENSE) * 00022 **********************************************************************************/ 00023 00024 #ifndef ROOT_TMVA_CCTreeWrapper 00025 #define ROOT_TMVA_CCTreeWrapper 00026 00027 #ifndef ROOT_TMVA_Event 00028 #include "TMVA/Event.h" 00029 #endif 00030 #ifndef ROOT_TMVA_SeparationBase 00031 #include "TMVA/SeparationBase.h" 00032 #endif 00033 #ifndef ROOT_TMVA_DecisionTree 00034 #include "TMVA/DecisionTree.h" 00035 #endif 00036 #ifndef ROOT_TMVA_DataSet 00037 #include "TMVA/DataSet.h" 00038 #endif 00039 #ifndef ROOT_TMVA_Version 00040 #include "TMVA/Version.h" 00041 #endif 00042 00043 00044 namespace TMVA { 00045 00046 class CCTreeWrapper { 00047 00048 public: 00049 00050 typedef std::vector<Event*> EventList; 00051 00052 ///////////////////////////////////////////////////////////// 00053 // CCTreeNode - a light wrapper of a decision tree node // 00054 // // 00055 ///////////////////////////////////////////////////////////// 00056 00057 class CCTreeNode : virtual public Node { 00058 00059 public: 00060 00061 CCTreeNode( DecisionTreeNode* n = NULL ); 00062 virtual ~CCTreeNode( ); 00063 00064 virtual Node* CreateNode() const { return new CCTreeNode(); } 00065 00066 // set |~T_t|, the number of terminal descendants of node t 00067 inline void SetNLeafDaughters( Int_t N ) { fNLeafDaughters = (N > 0 ? N : 0); } 00068 00069 // return |~T_t| 00070 inline Int_t GetNLeafDaughters() const { return fNLeafDaughters; } 00071 00072 // set R(t), the node resubstitution estimate (Gini, misclassification, etc.) for the node t 00073 inline void SetNodeResubstitutionEstimate( Double_t R ) { fNodeResubstitutionEstimate = (R >= 0 ? R : 0.0); } 00074 00075 // return R(t) for node t 00076 inline Double_t GetNodeResubstitutionEstimate( ) const { return fNodeResubstitutionEstimate; } 00077 00078 // set R(T_t) = sum[t' in ~T_t]{ R(t) }, the resubstitution estimate for the branch rooted at 00079 // node t (it is an estimate because it is calculated from the training dataset, i.e., the original tree) 00080 inline void SetResubstitutionEstimate( Double_t R ) { fResubstitutionEstimate = (R >= 0 ? R : 0.0); } 00081 00082 // return R(T_t) for node t 00083 inline Double_t GetResubstitutionEstimate( ) const { return fResubstitutionEstimate; } 00084 00085 // set the critical point of alpha 00086 // R(t) - R(T_t) 00087 // alpha_c < ------------- := g(t) 00088 // |~T_t| - 1 00089 // which is the value of alpha such that the branch rooted at node t is pruned 00090 inline void SetAlphaC( Double_t alpha ) { fAlphaC = alpha; } 00091 00092 // get the critical alpha value for this node 00093 inline Double_t GetAlphaC( ) const { return fAlphaC; } 00094 00095 // set the minimum critical alpha value for descendants of node t ( G(t) = min(alpha_c, g(t_l), g(t_r)) ) 00096 inline void SetMinAlphaC( Double_t alpha ) { fMinAlphaC = alpha; } 00097 00098 // get the minimum critical alpha value 00099 inline Double_t GetMinAlphaC( ) const { return fMinAlphaC; } 00100 00101 // get the pointer to the wrapped DT node 00102 inline DecisionTreeNode* GetDTNode( ) const { return fDTNode; } 00103 00104 // get pointers to children, mother in the CC tree 00105 inline CCTreeNode* GetLeftDaughter( ) { return dynamic_cast<CCTreeNode*>(GetLeft()); } 00106 inline CCTreeNode* GetRightDaughter( ) { return dynamic_cast<CCTreeNode*>(GetRight()); } 00107 inline CCTreeNode* GetMother( ) { return dynamic_cast<CCTreeNode*>(GetParent()); } 00108 00109 // printout of the node (can be read in with ReadDataRecord) 00110 virtual void Print( ostream& os ) const; 00111 00112 // recursive printout of the node and its daughters 00113 virtual void PrintRec ( ostream& os ) const; 00114 00115 virtual void AddAttributesToNode(void* node) const; 00116 virtual void AddContentToNode(std::stringstream& s) const; 00117 00118 00119 // test event if it decends the tree at this node to the right 00120 inline virtual Bool_t GoesRight( const Event& e ) const { return (GetDTNode() != NULL ? 00121 GetDTNode()->GoesRight(e) : false); } 00122 00123 // test event if it decends the tree at this node to the left 00124 inline virtual Bool_t GoesLeft ( const Event& e ) const { return (GetDTNode() != NULL ? 00125 GetDTNode()->GoesLeft(e) : false); } 00126 00127 private: 00128 00129 // initialize a node from a data record 00130 virtual void ReadAttributes(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE); 00131 virtual Bool_t ReadDataRecord( std::istream& in, UInt_t tmva_Version_Code = TMVA_VERSION_CODE ); 00132 virtual void ReadContent(std::stringstream& s); 00133 00134 Int_t fNLeafDaughters; //! number of terminal descendants 00135 Double_t fNodeResubstitutionEstimate; //! R(t) = misclassification rate for node t 00136 Double_t fResubstitutionEstimate; //! R(T_t) = sum[t' in ~T_t]{ R(t) } 00137 Double_t fAlphaC; //! critical point, g(t) = alpha_c(t) 00138 Double_t fMinAlphaC; //! G(t), minimum critical point of t and its descendants 00139 DecisionTreeNode* fDTNode; //! pointer to wrapped node in the decision tree 00140 }; 00141 00142 CCTreeWrapper( DecisionTree* T, SeparationBase* qualityIndex ); 00143 ~CCTreeWrapper( ); 00144 00145 // return the decision tree output for an event 00146 Double_t CheckEvent( const TMVA::Event & e, Bool_t useYesNoLeaf = false ); 00147 // return the misclassification rate of a pruned tree for a validation event sample 00148 Double_t TestTreeQuality( const EventList* validationSample ); 00149 Double_t TestTreeQuality( const DataSet* validationSample ); 00150 00151 // remove the branch rooted at node t 00152 void PruneNode( CCTreeNode* t ); 00153 // initialize the node t and all its descendants 00154 void InitTree( CCTreeNode* t ); 00155 00156 // return the root node for this tree 00157 CCTreeNode* GetRoot() { return fRoot; } 00158 private: 00159 SeparationBase* fQualityIndex; //! pointer to the used quality index calculator 00160 DecisionTree* fDTParent; //! pointer to underlying DecisionTree 00161 CCTreeNode* fRoot; //! the root node of the (wrapped) decision Tree 00162 }; 00163 00164 } 00165 00166 #endif 00167 00168 00169