00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026 #ifndef ROOT_TMVA_MethodRuleFit
00027 #define ROOT_TMVA_MethodRuleFit
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037 #ifndef ROOT_TMVA_MethodBase
00038 #include "TMVA/MethodBase.h"
00039 #endif
00040 #ifndef ROOT_TMatrixDfwd
00041 #include "TMatrixDfwd.h"
00042 #endif
00043 #ifndef ROOT_TVectorD
00044 #include "TVectorD.h"
00045 #endif
00046 #ifndef ROOT_TMVA_DecisionTree
00047 #include "TMVA/DecisionTree.h"
00048 #endif
00049 #ifndef ROOT_TMVA_RuleFit
00050 #include "TMVA/RuleFit.h"
00051 #endif
00052
00053 namespace TMVA {
00054
00055 class SeparationBase; // forward declaration: this header only uses SeparationBase through pointers
00056 // MethodRuleFit: a MethodBase-derived class that embeds a RuleFit object (fRuleFit).
00057 class MethodRuleFit : public MethodBase {
00058 // Besides fRuleFit it stores an event sample, a forest of decision trees and the option values exposed by the getters below.
00059 public:
00060 // Constructor taking job name, method title, data set info, an option string (default "") and a target directory (default 0).
00061 MethodRuleFit( const TString& jobName,
00062 const TString& methodTitle,
00063 DataSetInfo& theData,
00064 const TString& theOption = "",
00065 TDirectory* theTargetDir = 0 );
00066 // Constructor taking data set info and a weight-file name. NOTE(review): presumably rebuilds the method from stored weights - confirm.
00067 MethodRuleFit( DataSetInfo& theData,
00068 const TString& theWeightFile,
00069 TDirectory* theTargetDir = NULL );
00070 // Virtual destructor (defined outside this header).
00071 virtual ~MethodRuleFit( void );
00072 // Analysis-type query; the third UInt_t parameter is unnamed in this declaration.
00073 virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t );
00074
00075 // Training entry point (defined outside this header).
00076 void Train( void );
00077 // Keep the MethodBase overloads of ReadWeightsFromStream visible; the declaration further down would otherwise hide them.
00078 using MethodBase::ReadWeightsFromStream;
00079
00080 // NOTE(review): per its name, writes the weights as XML below the given parent node - confirm in the implementation.
00081 void AddWeightsXMLTo ( void* parent ) const;
00082
00083 // NOTE(review): per their names, read the weights back from a text stream or from an XML node - confirm.
00084 void ReadWeightsFromStream( istream& istr );
00085 void ReadWeightsFromXML ( void* wghtnode );
00086
00087 // Both pointer arguments are optional and default to 0. NOTE(review): presumably returns the classifier response - confirm.
00088 Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0 );
00089
00090 // NOTE(review): per its name, writes monitoring histograms to file; declared const.
00091 void WriteMonitoringHistosToFile( void ) const;
00092
00093 // Returns a pointer to a const Ranking. NOTE(review): ownership of the returned object is not visible here.
00094 const Ranking* CreateRanking();
00095 // Accessor for the fUseBoost flag.
00096 Bool_t UseBoost() const { return fUseBoost; }
00097
00098 // Inline accessors. Most return the like-named data member; GetMethodBaseDir() forwards to BaseDir().
00099 RuleFit* GetRuleFitPtr() { return &fRuleFit; }
00100 const RuleFit* GetRuleFitConstPtr() const { return &fRuleFit; }
00101 TDirectory* GetMethodBaseDir() const { return BaseDir(); }
00102 const std::vector<TMVA::Event*>& GetTrainingEvents() const { return fEventSample; }
00103 const std::vector<TMVA::DecisionTree*>& GetForest() const { return fForest; }
00104 Int_t GetNTrees() const { return fNTrees; }
00105 Double_t GetTreeEveFrac() const { return fTreeEveFrac; }
00106 const SeparationBase* GetSeparationBaseConst() const { return fSepType; }
00107 SeparationBase* GetSeparationBase() const { return fSepType; } // const method that hands out a non-const pointer
00108 TMVA::DecisionTree::EPruneMethod GetPruneMethod() const { return fPruneMethod; }
00109 Double_t GetPruneStrength() const { return fPruneStrength; }
00110 Double_t GetMinFracNEve() const { return fMinFracNEve; }
00111 Double_t GetMaxFracNEve() const { return fMaxFracNEve; }
00112 Int_t GetNCuts() const { return fNCuts; }
00113 // Accessors for a subset of the fGD* members.
00114 Int_t GetGDNPathSteps() const { return fGDNPathSteps; }
00115 Double_t GetGDPathStep() const { return fGDPathStep; }
00116 Double_t GetGDErrScale() const { return fGDErrScale; }
00117 Double_t GetGDPathEveFrac() const { return fGDPathEveFrac; }
00118 Double_t GetGDValidEveFrac() const { return fGDValidEveFrac; }
00119 // Accessor for fLinQuantile.
00120 Double_t GetLinQuantile() const { return fLinQuantile; }
00121 // Accessors for the fRF* members; the work-directory string is returned by value.
00122 const TString GetRFWorkDir() const { return fRFWorkDir; }
00123 Int_t GetRFNrules() const { return fRFNrules; }
00124 Int_t GetRFNendnodes() const { return fRFNendnodes; }
00125
00126 protected:
00127
00128 // NOTE(review): presumably emits the method-specific part of a generated class into the stream - confirm.
00129 void MakeClassSpecific( std::ostream&, const TString& ) const;
00130 // NOTE(review): per their names, the next two emit the rule-cut and the linear parts of that output - confirm.
00131 void MakeClassRuleCuts( std::ostream& ) const;
00132
00133 void MakeClassLinear( std::ostream& ) const;
00134
00135 // NOTE(review): per its name, prints the help text for this method - confirm.
00136 void GetHelpMessage() const;
00137
00138 // Initialisation hook (defined outside this header).
00139 void Init( void );
00140
00141 // NOTE(review): presumably fills fEventSample - confirm.
00142 void InitEventSample( void );
00143
00144 // NOTE(review): presumably sets up fMonitorNtuple and its fNT* buffers - confirm.
00145 void InitMonitorNtuple();
00146 // Two training routines. NOTE(review): presumably selected through fUseRuleFitJF - confirm in Train().
00147 void TrainTMVARuleFit();
00148 void TrainJFRuleFit();
00149
00150 private:
00151
00152 // Clamping range check: sets var to the violated bound, warns through mlog, and returns whether var was modified.
00153 template<typename T>
00154 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax );
00155 // Same check, but an out-of-range var is set to vdef instead of the violated bound.
00156 template<typename T>
00157 inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef );
00158 // Pure test: returns 1 if var>vmax, -1 if var<vmin, 0 otherwise; var is not modified.
00159 template<typename T>
00160 inline Int_t VerifyRange( const T& var, const T& vmin, const T& vmax );
00161
00162 // Option declaration and processing hooks (defined outside this header).
00163 void DeclareOptions();
00164 void ProcessOptions();
00165 // Core state.
00166 RuleFit fRuleFit; // embedded by value; exposed through GetRuleFitPtr()/GetRuleFitConstPtr()
00167 std::vector<TMVA::Event *> fEventSample; // raw pointers; ownership is not visible in this header
00168 Double_t fSignalFraction;
00169
00170 // fMonitorNtuple and the fNT* scalars. NOTE(review): the fNT* members look like per-entry buffers of that tree - confirm.
00171 TTree *fMonitorNtuple;
00172 Double_t fNTImportance;
00173 Double_t fNTCoefficient;
00174 Double_t fNTSupport;
00175 Int_t fNTNcuts;
00176 Int_t fNTNvars;
00177 Double_t fNTPtag;
00178 Double_t fNTPss;
00179 Double_t fNTPsb;
00180 Double_t fNTPbs;
00181 Double_t fNTPbb;
00182 Double_t fNTSSB;
00183 Int_t fNTType;
00184
00185 // Option values. NOTE(review): the TString members ending in 'S' look like raw option strings decoded into the typed members nearby - confirm in ProcessOptions().
00186 TString fRuleFitModuleS;
00187 Bool_t fUseRuleFitJF;
00188 TString fRFWorkDir;
00189 Int_t fRFNrules;
00190 Int_t fRFNendnodes;
00191 std::vector<DecisionTree *> fForest; // raw pointers; ownership is not visible in this header
00192 Int_t fNTrees;
00193 Double_t fTreeEveFrac;
00194 SeparationBase *fSepType;
00195 Double_t fMinFracNEve;
00196 Double_t fMaxFracNEve;
00197 Int_t fNCuts;
00198 TString fSepTypeS;
00199 TString fPruneMethodS;
00200 TMVA::DecisionTree::EPruneMethod fPruneMethod;
00201 Double_t fPruneStrength;
00202 TString fForestTypeS;
00203 Bool_t fUseBoost;
00204 // fGD* members; only fGDPathEveFrac, fGDValidEveFrac, fGDPathStep, fGDNPathSteps and fGDErrScale have getters above.
00205 Double_t fGDPathEveFrac;
00206 Double_t fGDValidEveFrac;
00207 Double_t fGDTau;
00208 Double_t fGDTauPrec;
00209 Double_t fGDTauMin;
00210 Double_t fGDTauMax;
00211 UInt_t fGDTauScan;
00212 Double_t fGDPathStep;
00213 Int_t fGDNPathSteps;
00214 Double_t fGDErrScale;
00215 Double_t fMinimp;
00216 // Of the three members below, only fLinQuantile has a getter in this header.
00217 TString fModelTypeS;
00218 Double_t fRuleMinDist;
00219 Double_t fLinQuantile;
00220 // ROOT ClassDef macro; the second argument (0) is the class version number.
00221 ClassDef(MethodRuleFit,0)
00222 };
00223
00224 } // namespace TMVA
00225
00226
00227
00228 template<typename T>
00229 inline Int_t TMVA::MethodRuleFit::VerifyRange( const T& var, const T& vmin, const T& vmax )
00230 {
00231 // Three-way range test: +1 when var exceeds vmax, -1 when it is under vmin, 0 otherwise.
00232 // The upper bound is examined first and the lower bound only if that test fails,
00233 // so a degenerate range (vmin > vmax) yields the same answer as before.
00234 return (var>vmax) ? 1 : ( (var<vmin) ? -1 : 0 );
00235 }
00236
00237
00238 template<typename T>
00239 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax )
00240 {
00241 // Clamp var into [vmin, vmax]. A warning naming the option (varstr) is sent to mlog
00242 // and kTRUE is returned only when var actually had to be changed.
00243 const Int_t side = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
00244
00245 // Any result other than +1 / -1 means nothing is modified: var stays as is, no message.
00246 if (side!=1 && side!=-1) return kFALSE;
00247
00248 const Bool_t tooHigh = (side==1);
00249 // Snap to whichever bound was crossed.
00250 if (tooHigh) var=vmax;
00251 else var=vmin;
00252
00253 mlog << kWARNING << "Option <" << varstr << "> " << (tooHigh ? "above":"below")
00254 << " allowed range. Reset to new value = " << var << Endl;
00255
00256 return kTRUE;
00257 }
00258
00259
00260 template<typename T>
00261 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef )
00262 {
00263 // Range check with fallback: an out-of-range var is replaced by vdef, not by the
00264 // nearest bound. vdef itself is not checked. Returns kTRUE only if var was replaced.
00265 const Int_t side = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
00266
00267 // Zero means var already lies inside the range: no change, no message.
00268 if (side==0) return kFALSE;
00269
00270 var=vdef;
00271 // The warning states which side was violated and the default now in effect.
00272 mlog << kWARNING << "Option <" << varstr << "> " << (side==1 ? "above":"below")
00273 << " allowed range. Reset to default value = " << var << Endl;
00274 return kTRUE;
00275 }
00276
00277
00278 #endif // ROOT_TMVA_MethodRuleFit