CCPruner.h

Go to the documentation of this file.
00001 #ifndef ROOT_TMVA_CCPruner
00002 #define ROOT_TMVA_CCPruner
00003 /**********************************************************************************
00004  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00005  * Package: TMVA                                                                  *
00006  * Class  : CCPruner                                                              *
00007  * Web    : http://tmva.sourceforge.net                                           *
00008  *                                                                                *
00009  * Description: Cost Complexity Pruning                                           *
00010  * 
00011  * Author: Doug Schouten (dschoute@sfu.ca)
00012  *
00013  *                                                                                *
00014  * Copyright (c) 2007:                                                            *
00015  *      CERN, Switzerland                                                         *
00016  *      MPI-K Heidelberg, Germany                                                 *
00017  *      U. of Texas at Austin, USA                                                *
00018  *                                                                                *
00019  * Redistribution and use in source and binary forms, with or without             *
00020  * modification, are permitted according to the terms listed in LICENSE           *
00021  * (http://tmva.sourceforge.net/LICENSE)                                          *
00022  **********************************************************************************/
00023 
00024 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
00025 // CCPruner - a helper class to prune a decision tree using the Cost Complexity method                    //
00026 // (see Classification and Regression Trees by Leo Breiman et al)                                         //
00027 //                                                                                                        //
00028 // Some definitions:                                                                                      //
00029 //                                                                                                        //
00030 // T_max - the initial, usually highly overtrained tree, that is to be pruned back                        // 
00031 // R(T) - quality index (Gini, misclassification rate, or other) of a tree T                              //
00032 // ~T - set of terminal nodes in T;  T_t - the branch of T rooted at node t (t and all its descendants)  //
00033 // T' - the pruned subtree of T_max that has the best quality index R(T')                                 //
00034 // alpha - the prune strength parameter in Cost Complexity pruning (R_alpha(T) = R(T) + alpha * |~T|)     //
00035 //                                                                                                        //
00036 // There are two running modes in CCPruner: (i) one may select a prune strength and prune back            //
00037 // the tree T_max until the criterion                                                                     //
00038 //             R(t) - R(T_t)                                                                              //
00039 //  alpha <    -------------                                                                              //
00040 //              |~T_t| - 1                                                                                //
00041 //                                                                                                        //
00042 // is true for all nodes t in T, or (ii) the algorithm finds the sequence of critical points              //
00043 // alpha_k < alpha_k+1 ... < alpha_K such that T_K = root(T_max) and then selects the optimally-pruned    //
00044 // subtree, defined to be the subtree with the best quality index for the validation sample.              //
00045 ////////////////////////////////////////////////////////////////////////////////////////////////////////////
00046 
00047 
00048 #ifndef ROOT_TMVA_DecisionTree
00049 #include "TMVA/DecisionTree.h"
00050 #endif
00051 
00052 /* #ifndef ROOT_TMVA_DecisionTreeNode */
00053 /* #include "TMVA/DecisionTreeNode.h" */
00054 /* #endif */
00055 
00056 #ifndef ROOT_TMVA_Event
00057 #include "TMVA/Event.h"
00058 #endif
00059 
00060 namespace TMVA {
00061    class DecisionTreeNode;
00062    class SeparationBase;
00063 
00064    class CCPruner {
00065    public: 
00066       typedef std::vector<Event*> EventList;
00067 
00068       CCPruner( DecisionTree* t_max, 
00069                 const EventList* validationSample,
00070                 SeparationBase* qualityIndex = NULL );
00071 
00072       CCPruner( DecisionTree* t_max, 
00073                 const DataSet* validationSample,
00074                 SeparationBase* qualityIndex = NULL );
00075 
00076       ~CCPruner( );
00077 
00078       // set the pruning strength parameter alpha (if alpha < 0, the optimal alpha is calculated)
00079       void SetPruneStrength( Float_t alpha = -1.0 );
00080 
00081       void Optimize( );
00082 
00083       // return the list of pruning locations to define the optimal subtree T' of T_max
00084       std::vector<TMVA::DecisionTreeNode*> GetOptimalPruneSequence( ) const; 
00085 
00086       // return the quality index from the validation sample for the optimal subtree T'
00087       inline Float_t GetOptimalQualityIndex( ) const { return (fOptimalK >= 0 && fQualityIndexList.size() > 0 ?
00088                                                                fQualityIndexList[fOptimalK] : -1.0); }
00089 
00090       // return the prune strength (=alpha) corresponding to the prune sequence
00091       inline Float_t GetOptimalPruneStrength( ) const { return (fOptimalK >= 0 && fPruneStrengthList.size() > 0 ?
00092                                                                 fPruneStrengthList[fOptimalK] : -1.0); }
00093    
00094    private:
00095       Float_t              fAlpha; //! regularization parameter in CC pruning
00096       const EventList*     fValidationSample; //! the event sample to select the optimally-pruned tree
00097       const DataSet*       fValidationDataSet; //! the event sample to select the optimally-pruned tree
00098       SeparationBase*      fQualityIndex; //! the quality index used to calculate R(t), R(T) = sum[t in ~T]{ R(t) }
00099       Bool_t               fOwnQIndex; //! flag indicates if fQualityIndex is owned by this
00100 
00101       DecisionTree*        fTree; //! (pruned) decision tree
00102 
00103       std::vector<TMVA::DecisionTreeNode*> fPruneSequence; //! map of weakest links (i.e., branches to prune) -> pruning index
00104       std::vector<Float_t> fPruneStrengthList;  //! map of alpha -> pruning index
00105       std::vector<Float_t> fQualityIndexList;   //! map of R(T) -> pruning index
00106 
00107       Int_t                fOptimalK;           //! index of the optimal tree in the pruned tree sequence
00108       Bool_t               fDebug;              //! debug flag
00109    };
00110 }
00111 
00112 inline void TMVA::CCPruner::SetPruneStrength( Float_t alpha ) {
00113   fAlpha = (alpha > 0 ? alpha : 0.0);
00114 }
00115     
00116 
00117 #endif
00118 
00119 

Generated on Tue Jul 5 14:27:18 2011 for ROOT_528-00b_version by  doxygen 1.5.1