00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 #ifndef ROOT_TMVA_MethodRuleFit
00027 #define ROOT_TMVA_MethodRuleFit
00028 
00029 
00030 
00031 
00032 
00033 
00034 
00035 
00036 
00037 #ifndef ROOT_TMVA_MethodBase
00038 #include "TMVA/MethodBase.h"
00039 #endif
00040 #ifndef ROOT_TMatrixDfwd
00041 #include "TMatrixDfwd.h"
00042 #endif
00043 #ifndef ROOT_TVectorD
00044 #include "TVectorD.h"
00045 #endif
00046 #ifndef ROOT_TMVA_DecisionTree
00047 #include "TMVA/DecisionTree.h"
00048 #endif
00049 #ifndef ROOT_TMVA_RuleFit
00050 #include "TMVA/RuleFit.h"
00051 #endif
00052 
00053 namespace TMVA {
00054 
00055    class SeparationBase;
00056 
00057    class MethodRuleFit : public MethodBase {
00058 
00059    public:
00060 
00061       MethodRuleFit( const TString& jobName,
00062                      const TString& methodTitle, 
00063                      DataSetInfo& theData,
00064                      const TString& theOption = "",
00065                      TDirectory* theTargetDir = 0 );
00066 
00067       MethodRuleFit( DataSetInfo& theData,
00068                      const TString& theWeightFile,
00069                      TDirectory* theTargetDir = NULL );
00070 
00071       virtual ~MethodRuleFit( void );
00072 
00073       virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t  );
00074 
00075       
00076       void Train( void );
00077 
00078       using MethodBase::ReadWeightsFromStream;
00079 
00080       
00081       void AddWeightsXMLTo     ( void* parent ) const;
00082 
00083       
00084       void ReadWeightsFromStream( istream& istr );
00085       void ReadWeightsFromXML   ( void* wghtnode );
00086 
00087       
00088       Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0 );
00089 
00090       
00091       void WriteMonitoringHistosToFile( void ) const;
00092 
00093       
00094       const Ranking* CreateRanking();
00095 
00096       Bool_t                                   UseBoost()           const   { return fUseBoost; }
00097 
00098       
00099       RuleFit*                                 GetRuleFitPtr()              { return &fRuleFit; }
00100       const RuleFit*                           GetRuleFitConstPtr() const   { return &fRuleFit; }
00101       TDirectory*                              GetMethodBaseDir()   const   { return BaseDir(); }
00102       const std::vector<TMVA::Event*>&         GetTrainingEvents()  const   { return fEventSample; }
00103       const std::vector<TMVA::DecisionTree*>&  GetForest()          const   { return fForest; }
00104       Int_t                                    GetNTrees()          const   { return fNTrees; }
00105       Double_t                                 GetTreeEveFrac()     const   { return fTreeEveFrac; }
00106       const SeparationBase*                    GetSeparationBaseConst() const { return fSepType; }
00107       SeparationBase*                          GetSeparationBase()  const   { return fSepType; }
00108       TMVA::DecisionTree::EPruneMethod         GetPruneMethod()     const   { return fPruneMethod; }
00109       Double_t                                 GetPruneStrength()   const   { return fPruneStrength; }
00110       Double_t                                 GetMinFracNEve()     const   { return fMinFracNEve; }
00111       Double_t                                 GetMaxFracNEve()     const   { return fMaxFracNEve; }
00112       Int_t                                    GetNCuts()           const   { return fNCuts; }
00113       
00114       Int_t                                    GetGDNPathSteps()    const   { return fGDNPathSteps; }
00115       Double_t                                 GetGDPathStep()      const   { return fGDPathStep; }
00116       Double_t                                 GetGDErrScale()      const   { return fGDErrScale; }
00117       Double_t                                 GetGDPathEveFrac()   const   { return fGDPathEveFrac; }
00118       Double_t                                 GetGDValidEveFrac()  const   { return fGDValidEveFrac; }
00119       
00120       Double_t                                 GetLinQuantile()     const   { return fLinQuantile; }
00121 
00122       const TString                            GetRFWorkDir()       const   { return fRFWorkDir; }
00123       Int_t                                    GetRFNrules()        const   { return fRFNrules; }
00124       Int_t                                    GetRFNendnodes()     const   { return fRFNendnodes; }
00125 
00126    protected:
00127 
00128       
00129       void MakeClassSpecific( std::ostream&, const TString& ) const;
00130 
00131       void MakeClassRuleCuts( std::ostream& ) const;
00132 
00133       void MakeClassLinear( std::ostream& ) const;
00134 
00135       
00136       void GetHelpMessage() const;
00137 
00138       
00139       void Init( void );
00140 
00141       
00142       void InitEventSample( void );
00143 
00144       
00145       void InitMonitorNtuple();
00146 
00147       void TrainTMVARuleFit();
00148       void TrainJFRuleFit();
00149 
00150    private:
00151 
00152       
00153       template<typename T>
00154       inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax );
00155 
00156       template<typename T>
00157       inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef );
00158 
00159       template<typename T>
00160       inline Int_t VerifyRange( const T& var, const T& vmin, const T& vmax );
00161 
00162       
00163       void DeclareOptions();
00164       void ProcessOptions();
00165 
00166       RuleFit                      fRuleFit;        
00167       std::vector<TMVA::Event *>   fEventSample;    
00168       Double_t                     fSignalFraction; 
00169 
00170       
00171       TTree                       *fMonitorNtuple;  
00172       Double_t                     fNTImportance;   
00173       Double_t                     fNTCoefficient;  
00174       Double_t                     fNTSupport;      
00175       Int_t                        fNTNcuts;        
00176       Int_t                        fNTNvars;        
00177       Double_t                     fNTPtag;         
00178       Double_t                     fNTPss;          
00179       Double_t                     fNTPsb;          
00180       Double_t                     fNTPbs;          
00181       Double_t                     fNTPbb;          
00182       Double_t                     fNTSSB;          
00183       Int_t                        fNTType;         
00184 
00185       
00186       TString                      fRuleFitModuleS;
00187       Bool_t                       fUseRuleFitJF;  
00188       TString                      fRFWorkDir;     
00189       Int_t                        fRFNrules;      
00190       Int_t                        fRFNendnodes;   
00191       std::vector<DecisionTree *>  fForest;        
00192       Int_t                        fNTrees;        
00193       Double_t                     fTreeEveFrac;   
00194       SeparationBase              *fSepType;       
00195       Double_t                     fMinFracNEve;   
00196       Double_t                     fMaxFracNEve;   
00197       Int_t                        fNCuts;         
00198       TString                      fSepTypeS;        
00199       TString                      fPruneMethodS;    
00200       TMVA::DecisionTree::EPruneMethod fPruneMethod; 
00201       Double_t                     fPruneStrength;   
00202       TString                      fForestTypeS;     
00203       Bool_t                       fUseBoost;        
00204       
00205       Double_t                     fGDPathEveFrac; 
00206       Double_t                     fGDValidEveFrac; 
00207       Double_t                     fGDTau;          
00208       Double_t                     fGDTauPrec;      
00209       Double_t                     fGDTauMin;       
00210       Double_t                     fGDTauMax;       
00211       UInt_t                       fGDTauScan;      
00212       Double_t                     fGDPathStep;     
00213       Int_t                        fGDNPathSteps;   
00214       Double_t                     fGDErrScale;     
00215       Double_t                     fMinimp;         
00216       
00217       TString                      fModelTypeS;     
00218       Double_t                     fRuleMinDist;    
00219       Double_t                     fLinQuantile;    
00220 
00221       ClassDef(MethodRuleFit,0)  
00222    };
00223 
00224 } 
00225 
00226 
00227 
00228 template<typename T>
00229 inline Int_t TMVA::MethodRuleFit::VerifyRange( const T& var, const T& vmin, const T& vmax )
00230 {
00231    
00232    if (var>vmax) return  1;
00233    if (var<vmin) return -1;
00234    return 0;
00235 }
00236 
00237 
00238 template<typename T>
00239 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax )
00240 {
00241    
00242    
00243    Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
00244    Bool_t modif=kFALSE;
00245    if (dir==1) {
00246       modif = kTRUE;
00247       var=vmax;
00248    }
00249    if (dir==-1) {
00250       modif = kTRUE;
00251       var=vmin;
00252    }
00253    if (modif) {
00254       mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to new value = " << var << Endl;
00255    }
00256    return modif;
00257 }
00258 
00259 
00260 template<typename T>
00261 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef )
00262 {
00263    
00264    
00265    Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
00266    Bool_t modif=kFALSE;
00267    if (dir!=0) {
00268       modif = kTRUE;
00269       var=vdef;
00270    }
00271    if (modif) {
00272       mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to default value = " << var << Endl;
00273    }
00274    return modif;
00275 }
00276 
00277 
00278 #endif // MethodRuleFit_H