00001 #ifndef ROOT_TMVA_CCPruner 00002 #define ROOT_TMVA_CCPruner 00003 /********************************************************************************** 00004 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * 00005 * Package: TMVA * 00006 * Class : CCPruner * 00007 * Web : http://tmva.sourceforge.net * 00008 * * 00009 * Description: Cost Complexity Pruning * 00010 * 00011 * Author: Doug Schouten (dschoute@sfu.ca) 00012 * 00013 * * 00014 * Copyright (c) 2007: * 00015 * CERN, Switzerland * 00016 * MPI-K Heidelberg, Germany * 00017 * U. of Texas at Austin, USA * 00018 * * 00019 * Redistribution and use in source and binary forms, with or without * 00020 * modification, are permitted according to the terms listed in LICENSE * 00021 * (http://tmva.sourceforge.net/LICENSE) * 00022 **********************************************************************************/ 00023 00024 //////////////////////////////////////////////////////////////////////////////////////////////////////////// 00025 // CCPruner - a helper class to prune a decision tree using the Cost Complexity method // 00026 // (see Classification and Regression Trees by Leo Breiman et al) // 00027 // // 00028 // Some definitions: // 00029 // // 00030 // T_max - the initial, usually highly overtrained tree, that is to be pruned back // 00031 // R(T) - quality index (Gini, misclassification rate, or other) of a tree T // 00032 // ~T - set of terminal nodes in T // 00033 // T' - the pruned subtree of T_max that has the best quality index R(T') // 00034 // alpha - the prune strength parameter in Cost Complexity pruning (R_alpha(T) = R(T) + alpha// |~T|) // 00035 // // 00036 // There are two running modes in CCPruner: (i) one may select a prune strength and prune back // 00037 // the tree T_max until the criterion // 00038 // R(T) - R(t) // 00039 // alpha < ---------- // 00040 // |~T_t| - 1 // 00041 // // 00042 // is true for all nodes t in T, or (ii) the algorithm finds the sequence of critical points // 00043 // alpha_k < alpha_k+1 ... < alpha_K such that T_K = root(T_max) and then selects the optimally-pruned // 00044 // subtree, defined to be the subtree with the best quality index for the validation sample. // 00045 //////////////////////////////////////////////////////////////////////////////////////////////////////////// 00046 00047 00048 #ifndef ROOT_TMVA_DecisionTree 00049 #include "TMVA/DecisionTree.h" 00050 #endif 00051 00052 /* #ifndef ROOT_TMVA_DecisionTreeNode */ 00053 /* #include "TMVA/DecisionTreeNode.h" */ 00054 /* #endif */ 00055 00056 #ifndef ROOT_TMVA_Event 00057 #include "TMVA/Event.h" 00058 #endif 00059 00060 namespace TMVA { 00061 class DecisionTreeNode; 00062 class SeparationBase; 00063 00064 class CCPruner { 00065 public: 00066 typedef std::vector<Event*> EventList; 00067 00068 CCPruner( DecisionTree* t_max, 00069 const EventList* validationSample, 00070 SeparationBase* qualityIndex = NULL ); 00071 00072 CCPruner( DecisionTree* t_max, 00073 const DataSet* validationSample, 00074 SeparationBase* qualityIndex = NULL ); 00075 00076 ~CCPruner( ); 00077 00078 // set the pruning strength parameter alpha (if alpha < 0, the optimal alpha is calculated) 00079 void SetPruneStrength( Float_t alpha = -1.0 ); 00080 00081 void Optimize( ); 00082 00083 // return the list of pruning locations to define the optimal subtree T' of T_max 00084 std::vector<TMVA::DecisionTreeNode*> GetOptimalPruneSequence( ) const; 00085 00086 // return the quality index from the validation sample for the optimal subtree T' 00087 inline Float_t GetOptimalQualityIndex( ) const { return (fOptimalK >= 0 && fQualityIndexList.size() > 0 ? 00088 fQualityIndexList[fOptimalK] : -1.0); } 00089 00090 // return the prune strength (=alpha) corresponding to the prune sequence 00091 inline Float_t GetOptimalPruneStrength( ) const { return (fOptimalK >= 0 && fPruneStrengthList.size() > 0 ? 00092 fPruneStrengthList[fOptimalK] : -1.0); } 00093 00094 private: 00095 Float_t fAlpha; //! regularization parameter in CC pruning 00096 const EventList* fValidationSample; //! the event sample to select the optimally-pruned tree 00097 const DataSet* fValidationDataSet; //! the event sample to select the optimally-pruned tree 00098 SeparationBase* fQualityIndex; //! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) } 00099 Bool_t fOwnQIndex; //! flag indicates if fQualityIndex is owned by this 00100 00101 DecisionTree* fTree; //! (pruned) decision tree 00102 00103 std::vector<TMVA::DecisionTreeNode*> fPruneSequence; //! map of weakest links (i.e., branches to prune) -> pruning index 00104 std::vector<Float_t> fPruneStrengthList; //! map of alpha -> pruning index 00105 std::vector<Float_t> fQualityIndexList; //! map of R(T) -> pruning index 00106 00107 Int_t fOptimalK; //! index of the optimal tree in the pruned tree sequence 00108 Bool_t fDebug; //! debug flag 00109 }; 00110 } 00111 00112 inline void TMVA::CCPruner::SetPruneStrength( Float_t alpha ) { 00113 fAlpha = (alpha > 0 ? alpha : 0.0); 00114 } 00115 00116 00117 #endif 00118 00119