DecisionTree.h

// @(#)root/tmva $Id: DecisionTree.h 37986 2011-02-04 21:42:15Z pcanal $
// Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : DecisionTree                                                          *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation of a Decision Tree                                         *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/license.txt)                                      *
 *                                                                                *
 **********************************************************************************/

#ifndef ROOT_TMVA_DecisionTree
#define ROOT_TMVA_DecisionTree

//////////////////////////////////////////////////////////////////////////
//                                                                      //
// DecisionTree                                                         //
//                                                                      //
// Implementation of a Decision Tree                                    //
//                                                                      //
//////////////////////////////////////////////////////////////////////////

#ifndef ROOT_TH2
#include "TH2.h"
#endif

#ifndef ROOT_TMVA_Types
#include "TMVA/Types.h"
#endif
#ifndef ROOT_TMVA_DecisionTreeNode
#include "TMVA/DecisionTreeNode.h"
#endif
#ifndef ROOT_TMVA_BinaryTree
#include "TMVA/BinaryTree.h"
#endif
#ifndef ROOT_TMVA_BinarySearchTree
#include "TMVA/BinarySearchTree.h"
#endif
#ifndef ROOT_TMVA_SeparationBase
#include "TMVA/SeparationBase.h"
#endif
#ifndef ROOT_TMVA_RegressionVariance
#include "TMVA/RegressionVariance.h"
#endif

class TRandom3;

namespace TMVA {

   class Event;

   class DecisionTree : public BinaryTree {

   private:

      static const Int_t fgRandomSeed; // set nonzero for debugging and zero for random seeds

   public:

      typedef std::vector<TMVA::Event*> EventList;

      // the constructor needed for reading the decision tree from weight files
      DecisionTree( void );

      // the constructor needed for constructing the decision tree via training on events
      DecisionTree( SeparationBase *sepType, Int_t minSize,
                    Int_t nCuts,
                    UInt_t cls =0,
                    Bool_t randomisedTree=kFALSE, Int_t useNvars=0, Bool_t usePoissonNvars=kFALSE,
                    UInt_t nNodesMax=999999, UInt_t nMaxDepth=9999999,
                    Int_t iSeed=fgRandomSeed, Float_t purityLimit=0.5,
                    Int_t treeID = 0);
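
      // Illustrative usage sketch (not part of this header): constructing a tree
      // for training, assuming the GiniIndex separation criterion from
      // "TMVA/GiniIndex.h"; all other arguments keep the defaults declared above.
      //
      //   TMVA::SeparationBase *sepType = new TMVA::GiniIndex();
      //   TMVA::DecisionTree   *dt      = new TMVA::DecisionTree( sepType,
      //                                                           100,   // minSize: minimum events per node
      //                                                           20 );  // nCuts:   grid points per cut scan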

      // copy constructor
      DecisionTree (const DecisionTree &d);

      virtual ~DecisionTree( void );

      // Retrieves the address of the root node
      virtual DecisionTreeNode* GetRoot() const { return dynamic_cast<TMVA::DecisionTreeNode*>(fRoot); }
      virtual DecisionTreeNode * CreateNode(UInt_t) const { return new DecisionTreeNode(); }
      virtual BinaryTree* CreateTree() const { return new DecisionTree(); }
      static  DecisionTree* CreateFromXML(void* node, UInt_t tmva_Version_Code = TMVA_VERSION_CODE);
      virtual const char* ClassName() const { return "DecisionTree"; }

      // building of a tree by recursively splitting the nodes

      UInt_t BuildTree( const EventList & eventSample,
                        DecisionTreeNode *node = NULL);
      // determine how a node is split (which variable, which cut value)

      Double_t TrainNode( const EventList & eventSample,  DecisionTreeNode *node ) { return TrainNodeFast( eventSample, node ); }
      Double_t TrainNodeFast( const EventList & eventSample,  DecisionTreeNode *node );
      Double_t TrainNodeFull( const EventList & eventSample,  DecisionTreeNode *node );
      void    GetRandomisedVariables(Bool_t *useVariable, UInt_t *variableMap, UInt_t & nVars);
      std::vector<Double_t>  GetFisherCoefficients(const EventList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher);
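
      // Illustrative sketch (not part of this header): growing the tree from an
      // event sample; "eventSample" is assumed to be a std::vector<TMVA::Event*>
      // (EventList) that the caller has already filled.
      //
      //   UInt_t nNodes = dt->BuildTree( eventSample );  // recursively splits, starting at the root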

      // fill a tree that already has a given structure (just to see how many
      // signal/background events end up in each node)

      void FillTree( EventList & eventSample);

      // fill the existing decision tree structure by passing events in from the
      // top node and seeing where they end up
      void FillEvent( TMVA::Event & event,
                      TMVA::DecisionTreeNode *node  );
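
      // Illustrative sketch (not part of this header): re-filling a tree whose
      // structure already exists (e.g. read back from a weight file) with a new
      // event sample, using "eventSample" as above.
      //
      //   dt->ClearTree();               // reset the node statistics, keep the structure
      //   dt->FillTree( eventSample );   // route each event from the root to a leaf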

      // returns: 1 = Signal (right),  -1 = Bkg (left)

      Double_t CheckEvent( const TMVA::Event & , Bool_t UseYesNoLeaf = kFALSE ) const;
      TMVA::DecisionTreeNode* GetEventNode(const TMVA::Event & e) const;
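
      // Illustrative sketch (not part of this header): evaluating a single event;
      // "event" is assumed to be an existing TMVA::Event*. UseYesNoLeaf chooses
      // between a hard per-leaf decision and a purity-based response (see the
      // comment above).
      //
      //   Double_t response = dt->CheckEvent( *event, kTRUE );
      //   TMVA::DecisionTreeNode *leaf = dt->GetEventNode( *event );  // leaf the event ends up in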

      // return the individual relative variable importance
      std::vector< Double_t > GetVariableImportance();

      Double_t GetVariableImportance(UInt_t ivar);
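
      // Illustrative sketch (not part of this header): retrieving the relative
      // importance of the input variables after the tree has been built.
      //
      //   std::vector< Double_t > imp     = dt->GetVariableImportance();
      //   Double_t                impVar0 = dt->GetVariableImportance( 0 );  // importance of variable 0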

      // clear the tree nodes (their S/N, Nevents etc.), just keep the structure of the tree

      void ClearTree();

      // set the pruning method
      enum EPruneMethod { kExpectedErrorPruning=0, kCostComplexityPruning, kNoPruning };
      void SetPruneMethod( EPruneMethod m = kCostComplexityPruning ) { fPruneMethod = m; }

      // recursive pruning of the tree; a validation sample is required for automatic pruning
      Double_t PruneTree( EventList* validationSample = NULL );

      // manage the pruning strength parameter (if < 0, the pruning process is automated)
      void SetPruneStrength( Double_t p ) { fPruneStrength = p; }
      Double_t GetPruneStrength( ) const { return fPruneStrength; }
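
      // Illustrative sketch (not part of this header): pruning a trained tree with
      // cost-complexity pruning; "validationSample" is assumed to be an EventList
      // that was not used for training. A negative prune strength lets the pruning
      // process run automatically, as noted above.
      //
      //   dt->SetPruneMethod( TMVA::DecisionTree::kCostComplexityPruning );
      //   dt->SetPruneStrength( -1.0 );                    // < 0: automatic
      //   Double_t pruneResult = dt->PruneTree( &validationSample );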

      // apply the pruning validation sample to the decision tree
      void ApplyValidationSample( const EventList* validationSample ) const;

      // return the misclassification rate of a pruned tree
      Double_t TestPrunedTreeQuality( const DecisionTreeNode* dt = NULL, Int_t mode=0 ) const;

      // pass a single validation event through a pruned decision tree
      void CheckEventWithPrunedTree( const TMVA::Event& ) const;

      // calculate the normalization factor for a pruning validation sample
      Double_t GetSumWeights( const EventList* validationSample ) const;

      void SetNodePurityLimit( Double_t p ) { fNodePurityLimit = p; }
      Double_t GetNodePurityLimit( ) const { return fNodePurityLimit; }

      void DescendTree( Node *n = NULL );
      void SetParentTreeInNodes( Node *n = NULL );

      // retrieve a node from the tree. Its position (up to a maximal tree depth of 64)
      // is coded as a sequence of left-right moves starting from the root, stored as
      // 0-1 bit patterns in the "long integer", together with the depth
      Node* GetNode( ULong_t sequence, UInt_t depth );
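
      // Illustrative sketch (not part of this header): addressing a node by its
      // left/right path from the root. With depth 3, a 3-bit pattern in the
      // sequence encodes the three moves; which bit value means "right" and the
      // bit ordering follow the implementation, so the pattern below only
      // illustrates the calling convention.
      //
      //   TMVA::Node *n = dt->GetNode( 0x5 /* bit pattern of moves */, 3 /* depth */ );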

      UInt_t CleanTree(DecisionTreeNode *node=NULL);

      void PruneNode(TMVA::DecisionTreeNode *node);

      // prune a node from the tree without deleting its descendants; allows one to
      // effectively prune a tree many times without making deep copies
      void PruneNodeInPlace( TMVA::DecisionTreeNode* node );


      UInt_t CountLeafNodes(TMVA::Node *n = NULL);

      void  SetTreeID(Int_t treeID){fTreeID = treeID;}
      Int_t GetTreeID(){return fTreeID;}

      Bool_t DoRegression() const { return fAnalysisType == Types::kRegression; }
      void SetAnalysisType (Types::EAnalysisType t) { fAnalysisType = t;}
      Types::EAnalysisType GetAnalysisType ( void ) { return fAnalysisType;}
      inline void SetUseFisherCuts(Bool_t t=kTRUE)  { fUseFisherCuts = t;}
      inline void SetMinLinCorrForFisher(Double_t min){fMinLinCorrForFisher = min;}
      inline void SetUseExclusiveVars(Bool_t t=kTRUE){fUseExclusiveVars = t;}
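
      // Illustrative sketch (not part of this header): switching the tree to
      // regression mode and enabling multivariate (Fisher) splits; the correlation
      // threshold of 0.8 is only an example value.
      //
      //   dt->SetAnalysisType( TMVA::Types::kRegression );
      //   dt->SetUseFisherCuts( kTRUE );
      //   dt->SetMinLinCorrForFisher( 0.8 );  // minimum linear correlation for a Fisher split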

   private:
      // utility functions

      // calculate the purity, S/(S+B), of a given event sample from the numbers of
      // signal and background events it contains
      Double_t SamplePurity(EventList eventSample);

      UInt_t    fNvars;          // number of variables used to separate S and B
      Int_t     fNCuts;          // number of grid points in the variable cut scans
      Bool_t    fUseFisherCuts;  // use multivariate splits based on the Fisher criterion
      Double_t  fMinLinCorrForFisher; // the minimum linear correlation between two variables demanded for use in the Fisher criterion in node splitting
      Bool_t    fUseExclusiveVars; // individual variables already used in the Fisher criterion are no longer analysed individually for node splitting

      SeparationBase *fSepType;  // the separation criterion
      RegressionVariance *fRegType;  // the separation criterion used in regression

      Double_t  fMinSize;        // min number of events in a node
      Double_t  fMinSepGain;     // minimum separation gain required to perform node splitting

      Bool_t    fUseSearchTree;  // do the cut scan with binary search trees or a simple event loop
      Double_t  fPruneStrength;  // a parameter to set the "amount" of pruning; needs to be adjusted

      EPruneMethod fPruneMethod; // method used for pruning

      Double_t  fNodePurityLimit;// purity limit to decide whether a node is signal

      Bool_t    fRandomisedTree; // choose a random set of variables at each node splitting
      Int_t     fUseNvars;       // the number of variables used in randomised trees
      Bool_t    fUsePoissonNvars; // use "fUseNvars" not as a fixed number but as the mean of a Poisson distribution in each split

      TRandom3  *fMyTrandom;     // random number generator for randomised trees

      std::vector< Double_t > fVariableImportance; // the relative importance of the different variables

      UInt_t     fNNodesMax;     // max # of nodes
      UInt_t     fMaxDepth;      // max depth
      UInt_t     fClass;         // class which is treated as signal when building the tree

      static const Int_t  fgDebugLevel = 0;     // debug level determining some printout/control plots etc.
      Int_t     fTreeID;        // just an ID number given to the tree; makes debugging easier as the tree knows its own identity

      Types::EAnalysisType  fAnalysisType;   // kClassification(=0=false) or kRegression(=1=true)

      ClassDef(DecisionTree,0)               // implementation of a Decision Tree
   };

} // namespace TMVA

#endif
