00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifndef ROOT_TMVA_DecisionTree
00030 #define ROOT_TMVA_DecisionTree
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040 #ifndef ROOT_TH2
00041 #include "TH2.h"
00042 #endif
00043
00044 #ifndef ROOT_TMVA_Types
00045 #include "TMVA/Types.h"
00046 #endif
00047 #ifndef ROOT_TMVA_DecisionTreeNode
00048 #include "TMVA/DecisionTreeNode.h"
00049 #endif
00050 #ifndef ROOT_TMVA_BinaryTree
00051 #include "TMVA/BinaryTree.h"
00052 #endif
00053 #ifndef ROOT_TMVA_BinarySearchTree
00054 #include "TMVA/BinarySearchTree.h"
00055 #endif
00056 #ifndef ROOT_TMVA_SeparationBase
00057 #include "TMVA/SeparationBase.h"
00058 #endif
00059 #ifndef ROOT_TMVA_RegressionVariance
00060 #include "TMVA/RegressionVariance.h"
00061 #endif
00062
00063 class TRandom3;
00064
00065 namespace TMVA {
00066
00067 class Event;
00068
00069 class DecisionTree : public BinaryTree {
00070
00071 private:
00072
00073 static const Int_t fgRandomSeed;
00074
00075 public:
00076
00077 typedef std::vector<TMVA::Event*> EventList;
00078
00079
00080 DecisionTree( void );
00081
00082
00083 DecisionTree( SeparationBase *sepType, Int_t minSize,
00084 Int_t nCuts,
00085 UInt_t cls =0,
00086 Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
00087 UInt_t nNodesMax=999999, UInt_t nMaxDepth=9999999,
00088 Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
00089 Int_t treeID = 0);
00090
00091
00092 DecisionTree (const DecisionTree &d);
00093
00094 virtual ~DecisionTree( void );
00095
00096
00097 virtual DecisionTreeNode* GetRoot() const { return dynamic_cast<TMVA::DecisionTreeNode*>(fRoot); }
00098 virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
00099 virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
00100 static DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
00101 virtual const char* ClassName() const { return "DecisionTree"; }
00102
00103
00104
00105 UInt_t BuildTree( const EventList & eventSample,
00106 DecisionTreeNode *node = NULL);
00107
00108
00109 Double_t TrainNode( const EventList & eventSample, DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
00110 Double_t TrainNodeFast( const EventList & eventSample, DecisionTreeNode *node );
00111 Double_t TrainNodeFull( const EventList & eventSample, DecisionTreeNode *node );
00112 void GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
00113 std::vector<Double_t> GetFisherCoefficients(const EventList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
00114
00115
00116
00117
00118 void FillTree( EventList & eventSample);
00119
00120
00121
00122 void FillEvent( TMVA::Event & event,
00123 TMVA::DecisionTreeNode *node );
00124
00125
00126
00127 Double_t CheckEvent( const TMVA::Event & , Bool_t UseYesNoLeaf = kFALSE ) const;
00128 TMVA::DecisionTreeNode* GetEventNode(const TMVA::Event & e) const;
00129
00130
00131 std::vector< Double_t > GetVariableImportance();
00132
00133 Double_t GetVariableImportance(UInt_t ivar);
00134
00135
00136
00137 void ClearTree();
00138
00139
00140 enum EPruneMethod { kExpectedErrorPruning=0, kCostComplexityPruning, kNoPruning };
00141 void SetPruneMethod( EPruneMethod m = kCostComplexityPruning ) { fPruneMethod = m; }
00142
00143
00144 Double_t PruneTree( EventList* validationSample = NULL );
00145
00146
00147 void SetPruneStrength( Double_t p ) { fPruneStrength = p; }
00148 Double_t GetPruneStrength( ) const { return fPruneStrength; }
00149
00150
00151 void ApplyValidationSample( const EventList* validationSample ) const;
00152
00153
00154 Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;
00155
00156
00157 void CheckEventWithPrunedTree( const TMVA::Event& ) const;
00158
00159
00160 Double_t GetSumWeights( const EventList* validationSample ) const;
00161
00162 void SetNodePurityLimit( Double_t p ) { fNodePurityLimit = p; }
00163 Double_t GetNodePurityLimit( ) const { return fNodePurityLimit; }
00164
00165 void DescendTree( Node *n = NULL );
00166 void SetParentTreeInNodes( Node *n = NULL );
00167
00168
00169
00170
00171 Node* GetNode( ULong_t sequence, UInt_t depth );
00172
00173 UInt_t CleanTree(DecisionTreeNode *node=NULL);
00174
00175 void PruneNode(TMVA::DecisionTreeNode *node);
00176
00177
00178
00179 void PruneNodeInPlace( TMVA::DecisionTreeNode* node );
00180
00181
00182 UInt_t CountLeafNodes(TMVA::Node *n = NULL);
00183
00184 void SetTreeID(Int_t treeID){fTreeID = treeID;};
00185 Int_t GetTreeID(){return fTreeID;};
00186
00187 Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }
00188 void SetAnalysisType (Types::EAnalysisType t) { fAnalysisType = t;}
00189 Types::EAnalysisType GetAnalysisType ( void ) { return fAnalysisType;}
00190 inline void SetUseFisherCuts(Bool_t t=kTRUE) { fUseFisherCuts = t;}
00191 inline void SetMinLinCorrForFisher(Double_t min){fMinLinCorrForFisher = min;}
00192 inline void SetUseExclusiveVars(Bool_t t=kTRUE){fUseExclusiveVars = t;}
00193
00194 private:
00195
00196
00197
00198
00199
00200
00201 Double_t SamplePurity(EventList eventSample);
00202
00203 UInt_t fNvars;
00204 Int_t fNCuts;
00205 Bool_t fUseFisherCuts;
00206 Double_t fMinLinCorrForFisher;
00207 Bool_t fUseExclusiveVars;
00208
00209 SeparationBase *fSepType;
00210 RegressionVariance *fRegType;
00211
00212 Double_t fMinSize;
00213 Double_t fMinSepGain;
00214
00215 Bool_t fUseSearchTree;
00216 Double_t fPruneStrength;
00217
00218 EPruneMethod fPruneMethod;
00219
00220 Double_t fNodePurityLimit;
00221
00222 Bool_t fRandomisedTree;
00223 Int_t fUseNvars;
00224 Bool_t fUsePoissonNvars;
00225
00226 TRandom3 *fMyTrandom;
00227
00228 std::vector< Double_t > fVariableImportance;
00229
00230 UInt_t fNNodesMax;
00231 UInt_t fMaxDepth;
00232 UInt_t fClass;
00233
00234 static const Int_t fgDebugLevel = 0;
00235 Int_t fTreeID;
00236
00237 Types::EAnalysisType fAnalysisType;
00238
00239 ClassDef(DecisionTree,0)
00240 };
00241
00242 }
00243
00244 #endif