00001 /********************************************************************************** 00002 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis * 00003 * Package: TMVA * 00004 * Class : TMVA::DecisionTree * 00005 * Web : http://tmva.sourceforge.net * 00006 * * 00007 * Description: * 00008 * IPruneTool - a helper interface class to prune a decision tree * 00009 * * 00010 * Authors (alphabetical): * 00011 * Doug Schouten <dschoute@sfu.ca> - Simon Fraser U., Canada * 00012 * * 00013 * Copyright (c) 2005: * 00014 * CERN, Switzerland * 00015 * MPI-K Heidelberg, Germany * 00016 * * 00017 * Redistribution and use in source and binary forms, with or without * 00018 * modification, are permitted according to the terms listed in LICENSE * 00019 * (http://mva.sourceforge.net/license.txt) * 00020 **********************************************************************************/ 00021 00022 #ifndef ROOT_TMVA_IPruneTool 00023 #define ROOT_TMVA_IPruneTool 00024 00025 #include <iosfwd> 00026 #include <vector> 00027 00028 #ifndef ROOT_TMVA_SeparationBase 00029 #include "TMVA/SeparationBase.h" 00030 #endif 00031 00032 #ifndef ROOT_TMVA_DecisionTree 00033 #include "TMVA/DecisionTree.h" 00034 #endif 00035 00036 namespace TMVA { 00037 00038 // class MsgLogger; 00039 00040 //////////////////////////////////////////////////////////// 00041 // Basic struct for saving relevant pruning information // 00042 //////////////////////////////////////////////////////////// 00043 class PruningInfo { 00044 00045 public: 00046 00047 PruningInfo( ) : QualityIndex(0), PruneStrength(0), PruneSequence(0) {} 00048 PruningInfo( Double_t q, Double_t alpha, std::vector<DecisionTreeNode*> sequence ); 00049 Double_t QualityIndex; //! quality measure for a pruned subtree T of T_max 00050 Double_t PruneStrength; //! the regularization parameter for pruning 00051 std::vector<DecisionTreeNode*> PruneSequence; //! the sequence of pruning locations in T_max that yields T 00052 }; 00053 00054 inline PruningInfo::PruningInfo( Double_t q, Double_t alpha, std::vector<DecisionTreeNode*> sequence ) 00055 : QualityIndex(q), PruneStrength(alpha), PruneSequence(sequence) {} 00056 00057 //////////////////////////////////////////////////////////////////////////////////////////////////////////// 00058 // IPruneTool - a helper interface class to prune a decision tree // 00059 // // 00060 // Any tool which implements the interface should provide two modes for tree pruning: // 00061 // 1. automatically find the "best" prune strength by minimizing the error rate on a test sample // 00062 // if SetAutomatic() is called, or if automatic = kTRUE argument is set in CalculatePruningInfo() // 00063 // In this case, the PruningInfo object returned contains the error rate of the optimally pruned // 00064 // tree, the optimal prune strength, and the sequence of nodes to prune to obtain the optimal // 00065 // pruned tree from the original DecisionTree // 00066 // // 00067 // 2. a user-provided pruning strength parameter is used to prune the tree, in which case the returned // 00068 // PruningInfo object has QualityIndex = -1, PruneStrength = user prune strength, and PruneSequence // 00069 // is the list of nodes to prune // 00070 //////////////////////////////////////////////////////////////////////////////////////////////////////////// 00071 00072 class IPruneTool { 00073 00074 public: 00075 00076 typedef std::vector<Event*> EventSample; 00077 00078 IPruneTool( ); 00079 virtual ~IPruneTool(); 00080 00081 public: 00082 00083 // returns the PruningInfo object for a given tree and test sample 00084 virtual PruningInfo* CalculatePruningInfo( DecisionTree* dt, const EventSample* testEvents = NULL, 00085 Bool_t isAutomatic = kFALSE ) = 0; 00086 00087 public: 00088 00089 // set the prune strength parameter to use in pruning 00090 inline void SetPruneStrength( Double_t alpha ) { fPruneStrength = alpha; } 00091 // return the prune strength the tool is using 00092 inline Double_t GetPruneStrength( ) const { return fPruneStrength; } 00093 00094 // if the prune strength parameter is < 0, the tool will automatically find an optimal strength 00095 // set the tool to automatic mode 00096 inline void SetAutomatic( ) { fPruneStrength = -1.0; }; 00097 inline Bool_t IsAutomatic( ) const { return fPruneStrength <= 0.0; } 00098 00099 protected: 00100 00101 // mutable MsgLogger* fLogger; //! output stream to save logging information 00102 // MsgLogger& Log() const { return *fLogger; } 00103 Double_t fPruneStrength; //! regularization parameter in pruning 00104 00105 00106 Double_t S, B; 00107 }; 00108 00109 inline IPruneTool::IPruneTool( ) : 00110 fPruneStrength(0.0), 00111 S(0), 00112 B(0) 00113 {} 00114 inline IPruneTool::~IPruneTool( ) {} 00115 00116 } 00117 00118 #endif