00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027 #ifndef ROOT_TMVA_RuleFit
00028 #define ROOT_TMVA_RuleFit
00029
00030 #include <algorithm>
00031
00032 #ifndef ROOT_TMVA_DecisionTree
00033 #include "TMVA/DecisionTree.h"
00034 #endif
00035 #ifndef ROOT_TMVA_RuleEnsemble
00036 #include "TMVA/RuleEnsemble.h"
00037 #endif
00038 #ifndef ROOT_TMVA_RuleFitParams
00039 #include "TMVA/RuleFitParams.h"
00040 #endif
00041 #ifndef ROOT_TMVA_Event
00042 #include "TMVA/Event.h"
00043 #endif
00044
00045 namespace TMVA {
00046
00047
00048 class MethodBase; // forward declaration: this header only uses pointers to MethodBase
00049 class MethodRuleFit; // forward declaration: this header only stores/returns a pointer to MethodRuleFit
00050 class MsgLogger; // forward declaration: this header only uses a pointer and a reference to MsgLogger
00051
00052 class RuleFit {
00053 // Aggregates the training events, a forest of decision trees, a RuleEnsemble and a RuleFitParams object, and exposes setters/accessors for them. Most members are only declared here; their definitions live outside this header.
00054 public:
00055
00056 // Construct from a MethodBase pointer. NOTE(review): how rfbase is consumed is defined out of line -- confirm there.
00057 RuleFit( const TMVA::MethodBase *rfbase );
00058
00059 // Constructor taking no arguments.
00060 RuleFit( void );
00061 // Virtual destructor, so deletion through a base-typed pointer dispatches correctly if RuleFit is ever derived from.
00062 virtual ~RuleFit( void );
00063 // Initialisation helpers (defined out of line). Presumably set up fNEveEffTrain and the method pointers -- TODO confirm.
00064 void InitNEveEff();
00065 void InitPtrs( const TMVA::MethodBase *rfbase );
00066 void Initialize( const TMVA::MethodBase *rfbase );
00067 // Set the message type (EMsgType); defined out of line, presumably applied to the logger -- TODO confirm.
00068 void SetMsgType( EMsgType t );
00069 // Supply the training events (vector of Event pointers, passed by const reference); defined out of line.
00070 void SetTrainingEvents( const std::vector<TMVA::Event *> & el );
00071 // Randomly permute fTrainingEventsRndm in place. NOTE(review): std::random_shuffle was deprecated in C++14 and removed in C++17; migrate to std::shuffle with an explicit random engine (requires <random>).
00072 void ReshuffleEvents() { std::random_shuffle(fTrainingEventsRndm.begin(),fTrainingEventsRndm.end()); }
00073 // Set the MethodBase pointer; defined out of line, presumably stores it in fMethodBase -- TODO confirm.
00074 void SetMethodBase( const MethodBase *rfbase );
00075
00076 // Defined out of line; presumably builds the forest of decision trees held in fForest -- TODO confirm.
00077 void MakeForest();
00078
00079 // Defined out of line; operates on the decision tree passed in (pointer is non-const, so the tree may be modified).
00080 void BuildTree( TMVA::DecisionTree *dt );
00081
00082 // Defined out of line; presumably stores the current event weights (see fEventWeights) -- TODO confirm.
00083 void SaveEventWeights();
00084
00085 // Defined out of line; presumably the inverse of SaveEventWeights() -- TODO confirm.
00086 void RestoreEventWeights();
00087
00088 // Defined out of line; presumably re-weights events using the given tree -- TODO confirm.
00089 void Boost( TMVA::DecisionTree *dt );
00090
00091 // Defined out of line; presumably computes/reports statistics for the forest -- TODO confirm.
00092 void ForestStatistics();
00093
00094 // Evaluate a single event and return a Double_t; the computation is defined out of line.
00095 Double_t EvalEvent( const Event& e );
00096
00097 // Return a weight sum over the given events. neve defaults to 0; NOTE(review): 0 presumably means "use all events" -- confirm in the implementation.
00098 Double_t CalcWeightSum( const std::vector<TMVA::Event *> *events, UInt_t neve=0 );
00099
00100 // Defined out of line; presumably runs the coefficient fit via fRuleFitParams -- TODO confirm.
00101 void FitCoefficients();
00102
00103 // Defined out of line; presumably computes rule/variable importance -- TODO confirm.
00104 void CalcImportance();
00105
00106 // Forwards to fRuleEnsemble.SetModelLinear().
00107 void SetModelLinear() { fRuleEnsemble.SetModelLinear(); }
00108 // Forwards to fRuleEnsemble.SetModelRules().
00109 void SetModelRules() { fRuleEnsemble.SetModelRules(); }
00110 // Forwards to fRuleEnsemble.SetModelFull().
00111 void SetModelFull() { fRuleEnsemble.SetModelFull(); }
00112 // Forwards minimp (default 0) to fRuleEnsemble.SetImportanceCut().
00113 void SetImportanceCut( Double_t minimp=0 ) { fRuleEnsemble.SetImportanceCut(minimp); }
00114 // Forwards d to fRuleEnsemble.SetRuleMinDist().
00115 void SetRuleMinDist( Double_t d ) { fRuleEnsemble.SetRuleMinDist(d); }
00116 // Gradient-descent settings, each forwarded unchanged to fRuleFitParams. Defaults: tau=0.0, path step=0.01, number of path steps=100.
00117 void SetGDTau( Double_t t=0.0 ) { fRuleFitParams.SetGDTau(t); }
00118 void SetGDPathStep( Double_t s=0.01 ) { fRuleFitParams.SetGDPathStep(s); }
00119 void SetGDNPathSteps( Int_t n=100 ) { fRuleFitParams.SetGDNPathSteps(n); }
00120 // The next three only write fVisHistsUseImp: to f, to kTRUE, and to kFALSE respectively.
00121 void SetVisHistsUseImp( Bool_t f ) { fVisHistsUseImp = f; }
00122 void UseImportanceVisHists() { fVisHistsUseImp = kTRUE; }
00123 void UseCoefficientsVisHists() { fVisHistsUseImp = kFALSE; }
00124 void MakeVisHists(); // histogram helpers from here to GetCorrVars() are defined out of line
00125 void FillVisHistCut(const Rule * rule, std::vector<TH2F *> & hlist); // hlist is taken by non-const reference
00126 void FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist); // hlist is taken by non-const reference
00127 void FillCut(TH2F* h2,const TMVA::Rule *rule,Int_t vind); // vind: presumably a variable index -- TODO confirm
00128 void FillLin(TH2F* h2,Int_t vind); // vind: presumably a variable index -- TODO confirm
00129 void FillCorr(TH2F* h2,const TMVA::Rule *rule,Int_t v1, Int_t v2); // v1, v2: presumably variable indices -- TODO confirm
00130 void NormVisHists(std::vector<TH2F *> & hlist);
00131 void MakeDebugHists();
00132 Bool_t GetCorrVars(TString & title, TString & var1, TString & var2); // all three strings are non-const references, so they may be written to
00133 // Read accessors. The two indexed getters use operator[] on fTrainingEvents: no bounds check on i, and GetTrainingEventWeight() dereferences the element without a null check.
00134 UInt_t GetNTreeSample() const { return fNTreeSample; }
00135 Double_t GetNEveEff() const { return fNEveEffTrain; }
00136 const Event* GetTrainingEvent(UInt_t i) const { return static_cast< const Event *>(fTrainingEvents[i]); }
00137 Double_t GetTrainingEventWeight(UInt_t i) const { return fTrainingEvents[i]->GetWeight(); }
00138
00139
00140 // Returns a const reference to the internal event vector (no copy); the reference is only valid while this object lives.
00141 const std::vector< TMVA::Event * > & GetTrainingEvents() const { return fTrainingEvents; }
00142
00143
00144 // Defined out of line; evevec is an output-style parameter (non-const reference). Presumably filled with nevents randomly chosen events -- TODO confirm.
00145 void GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);
00146 // Accessors for the owned sub-objects: const-reference getters for read access, *Ptr() getters return the address of the member for mutation.
00147 const std::vector< const TMVA::DecisionTree *> & GetForest() const { return fForest; }
00148 const RuleEnsemble & GetRuleEnsemble() const { return fRuleEnsemble; }
00149 RuleEnsemble * GetRuleEnsemblePtr() { return &fRuleEnsemble; }
00150 const RuleFitParams & GetRuleFitParams() const { return fRuleFitParams; }
00151 RuleFitParams * GetRuleFitParamsPtr() { return &fRuleFitParams; }
00152 const MethodRuleFit * GetMethodRuleFit() const { return fMethodRuleFit; }
00153 const MethodBase * GetMethodBase() const { return fMethodBase; }
00154
00155 private:
00156
00157 // Copy constructor is declared private, so code outside the class cannot copy-construct a RuleFit. NOTE(review): no copy-assignment operator is declared here, so the implicit one remains available if the members allow it.
00158 RuleFit( const RuleFit & other );
00159
00160 // Private copy helper, defined out of line.
00161 void Copy( const RuleFit & other );
00162 // Data members.
00163 std::vector<TMVA::Event *> fTrainingEvents; // events exposed through GetTrainingEvent*/GetTrainingEvents
00164 std::vector<TMVA::Event *> fTrainingEventsRndm; // the vector permuted by ReshuffleEvents()
00165 std::vector<Double_t> fEventWeights; // presumably the weights kept by SaveEventWeights() -- TODO confirm
00166 UInt_t fNTreeSample; // returned by GetNTreeSample()
00167 // Returned by GetNEveEff().
00168 Double_t fNEveEffTrain;
00169 std::vector< const TMVA::DecisionTree *> fForest; // returned by GetForest(); ownership of the trees is not visible here
00170 RuleEnsemble fRuleEnsemble; // held by value; target of the SetModel*/SetImportanceCut/SetRuleMinDist forwards
00171 RuleFitParams fRuleFitParams; // held by value; target of the SetGD* forwards
00172 const MethodRuleFit *fMethodRuleFit; // returned by GetMethodRuleFit()
00173 const MethodBase *fMethodBase; // returned by GetMethodBase()
00174 Bool_t fVisHistsUseImp; // toggled by SetVisHistsUseImp/UseImportanceVisHists/UseCoefficientsVisHists
00175 // Logger pointer is mutable so it can be used from const members; Log() dereferences it without a null check.
00176 mutable MsgLogger* fLogger;
00177 MsgLogger& Log() const { return *fLogger; }
00178 // Class-wide constant fixed at 0; presumably a random-number seed -- TODO confirm where it is used.
00179 static const Int_t randSEED = 0;
00180 // ROOT dictionary macro with class version 0.
00181 ClassDef(RuleFit,0)
00182 };
00183 }
00184
00185 #endif