00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030 #ifndef ROOT_TMVA_MethodBDT
00031 #define ROOT_TMVA_MethodBDT
00032
00033
00034
00035
00036
00037
00038
00039
00040
#include <iosfwd>
#include <map>
#include <utility>
#include <vector>
#ifndef ROOT_TH2
#include "TH2.h"
#endif
#ifndef ROOT_TTree
#include "TTree.h"
#endif
#ifndef ROOT_TMVA_MethodBase
#include "TMVA/MethodBase.h"
#endif
#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif
00057
00058 namespace TMVA {
00059
00060 class SeparationBase;
00061
00062 class MethodBDT : public MethodBase {
00063
00064 public:
00065
00066 MethodBDT( const TString& jobName,
00067 const TString& methodTitle,
00068 DataSetInfo& theData,
00069 const TString& theOption = "",
00070 TDirectory* theTargetDir = 0 );
00071
00072
00073 MethodBDT( DataSetInfo& theData,
00074 const TString& theWeightFile,
00075 TDirectory* theTargetDir = NULL );
00076
00077 virtual ~MethodBDT( void );
00078
00079 virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t numberTargets );
00080
00081
00082
00083
00084 void InitEventSample();
00085
00086
00087 virtual std::map<TString,Double_t> OptimizeTuningParameters(TString fomType="ROCIntegral", TString fitType="FitGA");
00088 virtual void SetTuneParameters(std::map<TString,Double_t> tuneParameters);
00089
00090
00091 void Train( void );
00092
00093
00094 void Reset( void );
00095
00096 using MethodBase::ReadWeightsFromStream;
00097
00098
00099 void AddWeightsXMLTo( void* parent ) const;
00100
00101
00102 void ReadWeightsFromStream( istream& istr );
00103 void ReadWeightsFromXML(void* parent);
00104
00105
00106 void WriteMonitoringHistosToFile( void ) const;
00107
00108
00109 Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0);
00110
00111 private:
00112 Double_t GetMvaValue( Double_t* err, Double_t* errUpper, UInt_t useNTrees );
00113
00114 public:
00115 const std::vector<Float_t>& GetMulticlassValues();
00116
00117
00118 const std::vector<Float_t>& GetRegressionValues();
00119
00120
00121 Double_t Boost( std::vector<TMVA::Event*>, DecisionTree *dt, Int_t iTree, UInt_t cls = 0);
00122
00123
00124 const Ranking* CreateRanking();
00125
00126
00127 void DeclareOptions();
00128 void ProcessOptions();
00129 void SetMaxDepth(Int_t d){fMaxDepth = d;}
00130 void SetNodeMinEvents(Int_t d){fNodeMinEvents = d;}
00131 void SetNTrees(Int_t d){fNTrees = d;}
00132 void SetAdaBoostBeta(Double_t b){fAdaBoostBeta = b;}
00133 void SetNodePurityLimit(Double_t l){fNodePurityLimit = l;}
00134
00135
00136
00137 inline const std::vector<TMVA::DecisionTree*> & GetForest() const;
00138
00139
00140 inline const std::vector<TMVA::Event*> & GetTrainingEvents() const;
00141
00142 inline const std::vector<double> & GetBoostWeights() const;
00143
00144
00145 std::vector<Double_t> GetVariableImportance();
00146 Double_t GetVariableImportance(UInt_t ivar);
00147
00148 Double_t TestTreeQuality( DecisionTree *dt );
00149
00150
00151 void MakeClassSpecific( std::ostream&, const TString& ) const;
00152
00153
00154 void MakeClassSpecificHeader( std::ostream&, const TString& ) const;
00155
00156 void MakeClassInstantiateNode( DecisionTreeNode *n, std::ostream& fout,
00157 const TString& className ) const;
00158
00159 void GetHelpMessage() const;
00160
00161 virtual Bool_t IsSignalLike() { return GetMvaValue() > 0;}
00162 protected:
00163 void DeclareCompatibilityOptions();
00164
00165 private:
00166
00167 void Init( void );
00168
00169
00170 Double_t AdaBoost( std::vector<TMVA::Event*>, DecisionTree *dt );
00171
00172
00173 Double_t Bagging( std::vector<TMVA::Event*>, Int_t iTree );
00174
00175
00176 Double_t RegBoost( std::vector<TMVA::Event*>, DecisionTree *dt );
00177
00178
00179 Double_t AdaBoostR2( std::vector<TMVA::Event*>, DecisionTree *dt );
00180
00181
00182
00183
00184 Double_t GradBoost( std::vector<TMVA::Event*>, DecisionTree *dt, UInt_t cls = 0);
00185 Double_t GradBoostRegression(std::vector<TMVA::Event*>, DecisionTree *dt );
00186 void InitGradBoost( std::vector<TMVA::Event*>);
00187 void UpdateTargets( std::vector<TMVA::Event*>, UInt_t cls = 0);
00188 void UpdateTargetsRegression( std::vector<TMVA::Event*>,Bool_t first=kFALSE);
00189 Double_t GetGradBoostMVA(TMVA::Event& e, UInt_t nTrees);
00190 void GetRandomSubSample();
00191 Double_t GetWeightedQuantile(std::vector<std::pair<Double_t, Double_t> > vec, const Double_t quantile, const Double_t SumOfWeights = 0.0);
00192
00193 std::vector<TMVA::Event*> fEventSample;
00194 std::vector<TMVA::Event*> fValidationSample;
00195 std::vector<TMVA::Event*> fSubSample;
00196 Int_t fNTrees;
00197 std::vector<DecisionTree*> fForest;
00198 std::vector<double> fBoostWeights;
00199 Bool_t fRenormByClass;
00200 TString fBoostType;
00201 Double_t fAdaBoostBeta;
00202 TString fAdaBoostR2Loss;
00203 Double_t fTransitionPoint;
00204 Double_t fShrinkage;
00205 Bool_t fBaggedGradBoost;
00206 Double_t fSampleFraction;
00207 Double_t fSumOfWeights;
00208 std::map< TMVA::Event*, std::pair<Double_t, Double_t> > fWeightedResiduals;
00209 std::map< TMVA::Event*,std::vector<double> > fResiduals;
00210
00211
00212 SeparationBase *fSepType;
00213 TString fSepTypeS;
00214 Int_t fNodeMinEvents;
00215
00216 Int_t fNCuts;
00217 Bool_t fUseFisherCuts;
00218 Double_t fMinLinCorrForFisher;
00219 Bool_t fUseExclusiveVars;
00220 Bool_t fUseYesNoLeaf;
00221 Double_t fNodePurityLimit;
00222 Bool_t fUseWeightedTrees;
00223 UInt_t fNNodesMax;
00224 UInt_t fMaxDepth;
00225
00226 DecisionTree::EPruneMethod fPruneMethod;
00227 TString fPruneMethodS;
00228 Double_t fPruneStrength;
00229 Bool_t fPruneBeforeBoost;
00230 Double_t fFValidationEvents;
00231 Bool_t fAutomatic;
00232 Bool_t fRandomisedTrees;
00233 UInt_t fUseNvars;
00234 Bool_t fUsePoissonNvars;
00235 UInt_t fUseNTrainEvents;
00236
00237 Double_t fSampleSizeFraction;
00238 Bool_t fNoNegWeightsInTraining;
00239
00240
00241
00242
00243 TTree* fMonitorNtuple;
00244 Int_t fITree;
00245 Double_t fBoostWeight;
00246 Double_t fErrorFraction;
00247
00248 std::vector<Double_t> fVariableImportance;
00249
00250
00251 static const Int_t fgDebugLevel;
00252
00253
00254
00255 ClassDef(MethodBDT,0)
00256 };
00257
00258 }
00259
00260 const std::vector<TMVA::DecisionTree*>& TMVA::MethodBDT::GetForest() const { return fForest; }
00261 const std::vector<TMVA::Event*>& TMVA::MethodBDT::GetTrainingEvents() const { return fEventSample; }
00262 const std::vector<double>& TMVA::MethodBDT::GetBoostWeights() const { return fBoostWeights; }
00263
00264 #endif