MethodRuleFit.h

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: MethodRuleFit.h 36966 2010-11-26 09:50:13Z evt $
00002 // Author: Fredrik Tegenfeldt
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : MethodRuleFit                                                         *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Friedman's RuleFit method                                                 * 
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
00015  *                                                                                *
00016  * Copyright (c) 2005:                                                            *
00017  *      CERN, Switzerland                                                         * 
00018  *      Iowa State U.                                                             *
00019  *      MPI-K Heidelberg, Germany                                                 * 
00020  *                                                                                *
00021  * Redistribution and use in source and binary forms, with or without             *
00022  * modification, are permitted according to the terms listed in LICENSE           *
00023  *                                                                                *
00024  **********************************************************************************/
00025 
00026 #ifndef ROOT_TMVA_MethodRuleFit
00027 #define ROOT_TMVA_MethodRuleFit
00028 
00029 //////////////////////////////////////////////////////////////////////////
00030 //                                                                      //
00031 // MethodRuleFit                                                        //
00032 //                                                                      //
00033 // J Friedman's RuleFit method                                          //
00034 //                                                                      //
00035 //////////////////////////////////////////////////////////////////////////
00036 
00037 #ifndef ROOT_TMVA_MethodBase
00038 #include "TMVA/MethodBase.h"
00039 #endif
00040 #ifndef ROOT_TMatrixDfwd
00041 #include "TMatrixDfwd.h"
00042 #endif
00043 #ifndef ROOT_TVectorD
00044 #include "TVectorD.h"
00045 #endif
00046 #ifndef ROOT_TMVA_DecisionTree
00047 #include "TMVA/DecisionTree.h"
00048 #endif
00049 #ifndef ROOT_TMVA_RuleFit
00050 #include "TMVA/RuleFit.h"
00051 #endif
00052 
00053 namespace TMVA {
00054 
00055    class SeparationBase;
00056 
00057    class MethodRuleFit : public MethodBase {
00058 
00059    public:
00060 
00061       MethodRuleFit( const TString& jobName,
00062                      const TString& methodTitle, 
00063                      DataSetInfo& theData,
00064                      const TString& theOption = "",
00065                      TDirectory* theTargetDir = 0 );
00066 
00067       MethodRuleFit( DataSetInfo& theData,
00068                      const TString& theWeightFile,
00069                      TDirectory* theTargetDir = NULL );
00070 
00071       virtual ~MethodRuleFit( void );
00072 
00073       virtual Bool_t HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t /*numberTargets*/ );
00074 
00075       // training method
00076       void Train( void );
00077 
00078       using MethodBase::ReadWeightsFromStream;
00079 
00080       // write weights to file
00081       void AddWeightsXMLTo     ( void* parent ) const;
00082 
00083       // read weights from file
00084       void ReadWeightsFromStream( istream& istr );
00085       void ReadWeightsFromXML   ( void* wghtnode );
00086 
00087       // calculate the MVA value
00088       Double_t GetMvaValue( Double_t* err = 0, Double_t* errUpper = 0 );
00089 
00090       // write method specific histos to target file
00091       void WriteMonitoringHistosToFile( void ) const;
00092 
00093       // ranking of input variables
00094       const Ranking* CreateRanking();
00095 
00096       Bool_t                                   UseBoost()           const   { return fUseBoost; }
00097 
00098       // accessors
00099       RuleFit*                                 GetRuleFitPtr()              { return &fRuleFit; }
00100       const RuleFit*                           GetRuleFitConstPtr() const   { return &fRuleFit; }
00101       TDirectory*                              GetMethodBaseDir()   const   { return BaseDir(); }
00102       const std::vector<TMVA::Event*>&         GetTrainingEvents()  const   { return fEventSample; }
00103       const std::vector<TMVA::DecisionTree*>&  GetForest()          const   { return fForest; }
00104       Int_t                                    GetNTrees()          const   { return fNTrees; }
00105       Double_t                                 GetTreeEveFrac()     const   { return fTreeEveFrac; }
00106       const SeparationBase*                    GetSeparationBaseConst() const { return fSepType; }
00107       SeparationBase*                          GetSeparationBase()  const   { return fSepType; }
00108       TMVA::DecisionTree::EPruneMethod         GetPruneMethod()     const   { return fPruneMethod; }
00109       Double_t                                 GetPruneStrength()   const   { return fPruneStrength; }
00110       Double_t                                 GetMinFracNEve()     const   { return fMinFracNEve; }
00111       Double_t                                 GetMaxFracNEve()     const   { return fMaxFracNEve; }
00112       Int_t                                    GetNCuts()           const   { return fNCuts; }
00113       //
00114       Int_t                                    GetGDNPathSteps()    const   { return fGDNPathSteps; }
00115       Double_t                                 GetGDPathStep()      const   { return fGDPathStep; }
00116       Double_t                                 GetGDErrScale()      const   { return fGDErrScale; }
00117       Double_t                                 GetGDPathEveFrac()   const   { return fGDPathEveFrac; }
00118       Double_t                                 GetGDValidEveFrac()  const   { return fGDValidEveFrac; }
00119       //
00120       Double_t                                 GetLinQuantile()     const   { return fLinQuantile; }
00121 
00122       const TString                            GetRFWorkDir()       const   { return fRFWorkDir; }
00123       Int_t                                    GetRFNrules()        const   { return fRFNrules; }
00124       Int_t                                    GetRFNendnodes()     const   { return fRFNendnodes; }
00125 
00126    protected:
00127 
00128       // make ROOT-independent C++ class for classifier response (classifier-specific implementation)
00129       void MakeClassSpecific( std::ostream&, const TString& ) const;
00130 
00131       void MakeClassRuleCuts( std::ostream& ) const;
00132 
00133       void MakeClassLinear( std::ostream& ) const;
00134 
00135       // get help message text
00136       void GetHelpMessage() const;
00137 
00138       // initialize rulefit
00139       void Init( void );
00140 
00141       // copy all training events into a stl::vector
00142       void InitEventSample( void );
00143 
00144       // initialize monitor ntuple
00145       void InitMonitorNtuple();
00146 
00147       void TrainTMVARuleFit();
00148       void TrainJFRuleFit();
00149 
00150    private:
00151 
00152       // check variable range and set var to lower or upper if out of range
00153       template<typename T>
00154       inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax );
00155 
00156       template<typename T>
00157       inline Bool_t VerifyRange( MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef );
00158 
00159       template<typename T>
00160       inline Int_t VerifyRange( const T& var, const T& vmin, const T& vmax );
00161 
00162       // the option handling methods
00163       void DeclareOptions();
00164       void ProcessOptions();
00165 
00166       RuleFit                      fRuleFit;        // RuleFit instance
00167       std::vector<TMVA::Event *>   fEventSample;    // the complete training sample
00168       Double_t                     fSignalFraction; // scalefactor for bkg events to modify initial s/b fraction in training data
00169 
00170       // ntuple
00171       TTree                       *fMonitorNtuple;  // pointer to monitor rule ntuple
00172       Double_t                     fNTImportance;   // ntuple: rule importance
00173       Double_t                     fNTCoefficient;  // ntuple: rule coefficient
00174       Double_t                     fNTSupport;      // ntuple: rule support
00175       Int_t                        fNTNcuts;        // ntuple: rule number of cuts
00176       Int_t                        fNTNvars;        // ntuple: rule number of vars
00177       Double_t                     fNTPtag;         // ntuple: rule P(tag)
00178       Double_t                     fNTPss;          // ntuple: rule P(tag s, true s)
00179       Double_t                     fNTPsb;          // ntuple: rule P(tag s, true b)
00180       Double_t                     fNTPbs;          // ntuple: rule P(tag b, true s)
00181       Double_t                     fNTPbb;          // ntuple: rule P(tag b, true b)
00182       Double_t                     fNTSSB;          // ntuple: rule S/(S+B)
00183       Int_t                        fNTType;         // ntuple: rule type (+1->signal, -1->bkg)
00184 
00185       // options
00186       TString                      fRuleFitModuleS;// which rulefit module to use
00187       Bool_t                       fUseRuleFitJF;  // if true interface with J.Friedmans RuleFit module
00188       TString                      fRFWorkDir;     // working directory from Friedmans module
00189       Int_t                        fRFNrules;      // max number of rules (only Friedmans module)
00190       Int_t                        fRFNendnodes;   // max number of rules (only Friedmans module)
00191       std::vector<DecisionTree *>  fForest;        // the forest
00192       Int_t                        fNTrees;        // number of trees in forest
00193       Double_t                     fTreeEveFrac;   // fraction of events used for traing each tree
00194       SeparationBase              *fSepType;       // the separation used in node splitting
00195       Double_t                     fMinFracNEve;   // min fraction of number events
00196       Double_t                     fMaxFracNEve;   // ditto max
00197       Int_t                        fNCuts;         // grid used in cut applied in node splitting
00198       TString                      fSepTypeS;        // forest generation: separation type - see DecisionTree
00199       TString                      fPruneMethodS;    // forest generation: prune method - see DecisionTree
00200       TMVA::DecisionTree::EPruneMethod fPruneMethod; // forest generation: method used for pruning - see DecisionTree 
00201       Double_t                     fPruneStrength;   // forest generation: prune strength - see DecisionTree
00202       TString                      fForestTypeS;     // forest generation: how the trees are generated
00203       Bool_t                       fUseBoost;        // use boosted events for forest generation
00204       //
00205       Double_t                     fGDPathEveFrac; //  GD path: fraction of subsamples used for the fitting
00206       Double_t                     fGDValidEveFrac; // GD path: fraction of subsamples used for the fitting
00207       Double_t                     fGDTau;          // GD path: def threshhold fraction [0..1]
00208       Double_t                     fGDTauPrec;      // GD path: precision of estimated tau
00209       Double_t                     fGDTauMin;       // GD path: min threshhold fraction [0..1]
00210       Double_t                     fGDTauMax;       // GD path: max threshhold fraction [0..1]
00211       UInt_t                       fGDTauScan;      // GD path: number of points to scan
00212       Double_t                     fGDPathStep;     // GD path: step size in path
00213       Int_t                        fGDNPathSteps;   // GD path: number of steps
00214       Double_t                     fGDErrScale;     // GD path: stop 
00215       Double_t                     fMinimp;         // rule/linear: minimum importance
00216       //
00217       TString                      fModelTypeS;     // rule ensemble: which model (rule,linear or both)
00218       Double_t                     fRuleMinDist;    // rule min distance - see RuleEnsemble
00219       Double_t                     fLinQuantile;    // quantile cut to remove outliers - see RuleEnsemble
00220 
00221       ClassDef(MethodRuleFit,0)  // Friedman's RuleFit method
00222    };
00223 
00224 } // namespace TMVA
00225 
00226 
00227 //_______________________________________________________________________
00228 template<typename T>
00229 inline Int_t TMVA::MethodRuleFit::VerifyRange( const T& var, const T& vmin, const T& vmax )
00230 {
00231    // check range and return +1 if above, -1 if below or 0 if inside
00232    if (var>vmax) return  1;
00233    if (var<vmin) return -1;
00234    return 0;
00235 }
00236 
00237 //_______________________________________________________________________
00238 template<typename T>
00239 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax )
00240 {
00241    // verify range and print out message
00242    // if outside range, set to closest limit
00243    Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
00244    Bool_t modif=kFALSE;
00245    if (dir==1) {
00246       modif = kTRUE;
00247       var=vmax;
00248    }
00249    if (dir==-1) {
00250       modif = kTRUE;
00251       var=vmin;
00252    }
00253    if (modif) {
00254       mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to new value = " << var << Endl;
00255    }
00256    return modif;
00257 }
00258 
00259 //_______________________________________________________________________
00260 template<typename T>
00261 inline Bool_t TMVA::MethodRuleFit::VerifyRange( TMVA::MsgLogger& mlog, const char *varstr, T& var, const T& vmin, const T& vmax, const T& vdef )
00262 {
00263    // verify range and print out message
00264    // if outside range, set to given default value
00265    Int_t dir = TMVA::MethodRuleFit::VerifyRange(var,vmin,vmax);
00266    Bool_t modif=kFALSE;
00267    if (dir!=0) {
00268       modif = kTRUE;
00269       var=vdef;
00270    }
00271    if (modif) {
00272       mlog << kWARNING << "Option <" << varstr << "> " << (dir==1 ? "above":"below") << " allowed range. Reset to default value = " << var << Endl;
00273    }
00274    return modif;
00275 }
00276 
00277 
00278 #endif // ROOT_TMVA_MethodRuleFit

Generated on Tue Jul 5 14:27:32 2011 for ROOT_528-00b_version by  doxygen 1.5.1