RuleFit.h

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: RuleFit.h 29195 2009-06-24 10:39:49Z brun $
00002 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : RuleFit                                                               *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      A class implementing various fits of rule ensembles                       *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
00015  *      Helge Voss         <Helge.Voss@cern.ch>         - MPI-KP Heidelberg, Ger. *
00016  *                                                                                *
00017  * Copyright (c) 2005:                                                            *
00018  *      CERN, Switzerland                                                         *
00019  *      Iowa State U.                                                             *
00020  *      MPI-K Heidelberg, Germany                                                 *
00021  *                                                                                *
00022  * Redistribution and use in source and binary forms, with or without             *
00023  * modification, are permitted according to the terms listed in LICENSE           *
00024  * (http://tmva.sourceforge.net/LICENSE)                                          *
00025  **********************************************************************************/
00026 
00027 #ifndef ROOT_TMVA_RuleFit
00028 #define ROOT_TMVA_RuleFit
00029 
#include <algorithm>
#include <cstdlib>
#include <vector>

#ifndef ROOT_TMVA_DecisionTree
#include "TMVA/DecisionTree.h"
#endif
#ifndef ROOT_TMVA_RuleEnsemble
#include "TMVA/RuleEnsemble.h"
#endif
#ifndef ROOT_TMVA_RuleFitParams
#include "TMVA/RuleFitParams.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif
00044 
00045 namespace TMVA {
00046 
00047 
00048    class MethodBase;
00049    class MethodRuleFit;
00050    class MsgLogger;
00051 
00052    class RuleFit {
00053 
00054    public:
00055 
00056       // main constructor
00057       RuleFit( const TMVA::MethodBase *rfbase );
00058 
00059       // empty constructor
00060       RuleFit( void );
00061 
00062       virtual ~RuleFit( void );
00063 
00064       void InitNEveEff();
00065       void InitPtrs( const TMVA::MethodBase *rfbase );
00066       void Initialize(  const TMVA::MethodBase *rfbase );
00067 
00068       void SetMsgType( EMsgType t );
00069 
00070       void SetTrainingEvents( const std::vector<TMVA::Event *> & el );
00071 
00072       void ReshuffleEvents() { std::random_shuffle(fTrainingEventsRndm.begin(),fTrainingEventsRndm.end()); }
00073 
00074       void SetMethodBase( const MethodBase *rfbase );
00075 
00076       // make the forest of trees for rule generation
00077       void MakeForest();
00078 
00079       // build a tree
00080       void BuildTree( TMVA::DecisionTree *dt );
00081 
00082       // save event weights
00083       void SaveEventWeights();
00084 
00085       // restore saved event weights
00086       void RestoreEventWeights();
00087 
00088       // boost events based on the given tree
00089       void Boost( TMVA::DecisionTree *dt );
00090 
00091       // calculate and print some statistics on the given forest
00092       void ForestStatistics();
00093 
00094       // calculate the discriminating variable for the given event
00095       Double_t EvalEvent( const Event& e );
00096 
00097       // calculate sum of 
00098       Double_t CalcWeightSum( const std::vector<TMVA::Event *> *events, UInt_t neve=0 );
00099 
00100       // do the fitting of the coefficients
00101       void     FitCoefficients();
00102 
00103       // calculate variable and rule importance from a set of events
00104       void     CalcImportance();
00105 
00106       // set usage of linear term
00107       void     SetModelLinear()                      { fRuleEnsemble.SetModelLinear(); }
00108       // set usage of rules
00109       void     SetModelRules()                       { fRuleEnsemble.SetModelRules(); }
00110       // set usage of linear term
00111       void     SetModelFull()                        { fRuleEnsemble.SetModelFull(); }
00112       // set minimum importance allowed
00113       void     SetImportanceCut( Double_t minimp=0 ) { fRuleEnsemble.SetImportanceCut(minimp); }
00114       // set minimum rule distance - see RuleEnsemble
00115       void     SetRuleMinDist( Double_t d )          { fRuleEnsemble.SetRuleMinDist(d); }
00116       // set path related parameters
00117       void     SetGDTau( Double_t t=0.0 )       { fRuleFitParams.SetGDTau(t); }
00118       void     SetGDPathStep( Double_t s=0.01 ) { fRuleFitParams.SetGDPathStep(s); }
00119       void     SetGDNPathSteps( Int_t n=100 )   { fRuleFitParams.SetGDNPathSteps(n); }
00120       // make visualization histograms
00121       void     SetVisHistsUseImp( Bool_t f ) { fVisHistsUseImp = f; }
00122       void     UseImportanceVisHists()       { fVisHistsUseImp = kTRUE; }
00123       void     UseCoefficientsVisHists()     { fVisHistsUseImp = kFALSE; }
00124       void     MakeVisHists();
00125       void     FillVisHistCut(const Rule * rule, std::vector<TH2F *> & hlist);
00126       void     FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist);
00127       void     FillCut(TH2F* h2,const TMVA::Rule *rule,Int_t vind);
00128       void     FillLin(TH2F* h2,Int_t vind);
00129       void     FillCorr(TH2F* h2,const TMVA::Rule *rule,Int_t v1, Int_t v2);
00130       void     NormVisHists(std::vector<TH2F *> & hlist);
00131       void     MakeDebugHists();
00132       Bool_t   GetCorrVars(TString & title, TString & var1, TString & var2);
00133       // accessors
00134       UInt_t        GetNTreeSample()            const { return fNTreeSample; }
00135       Double_t      GetNEveEff()                const { return fNEveEffTrain; } // reweighted number of events = sum(wi)
00136       const Event*  GetTrainingEvent(UInt_t i)  const { return static_cast< const Event *>(fTrainingEvents[i]); }
00137       Double_t      GetTrainingEventWeight(UInt_t i)  const { return fTrainingEvents[i]->GetWeight(); }
00138 
00139       //      const Event*  GetTrainingEvent(UInt_t i, UInt_t isub)  const { return &(fTrainingEvents[fSubsampleEvents[isub]])[i]; }
00140 
00141       const std::vector< TMVA::Event * > & GetTrainingEvents()  const { return fTrainingEvents; }
00142       //      const std::vector< Int_t >               & GetSubsampleEvents() const { return fSubsampleEvents; }
00143 
00144       //      void  GetSubsampleEvents(Int_t sub, UInt_t & ibeg, UInt_t & iend) const;
00145       void  GetRndmSampleEvents(std::vector< const TMVA::Event * > & evevec, UInt_t nevents);
00146       //
00147       const std::vector< const TMVA::DecisionTree *> & GetForest()     const { return fForest; }
00148       const RuleEnsemble                       & GetRuleEnsemble()     const { return fRuleEnsemble; }
00149             RuleEnsemble                       * GetRuleEnsemblePtr()        { return &fRuleEnsemble; }
00150       const RuleFitParams                      & GetRuleFitParams()    const { return fRuleFitParams; }
00151             RuleFitParams                      * GetRuleFitParamsPtr()       { return &fRuleFitParams; }
00152       const MethodRuleFit                      * GetMethodRuleFit()    const { return fMethodRuleFit; }
00153       const MethodBase                         * GetMethodBase()       const { return fMethodBase; }
00154 
00155    private:
00156 
00157       // copy constructor
00158       RuleFit( const RuleFit & other );
00159 
00160       // copy method
00161       void Copy( const RuleFit & other );
00162 
00163       std::vector<TMVA::Event *>          fTrainingEvents;      // all training events
00164       std::vector<TMVA::Event *>          fTrainingEventsRndm;  // idem, but randomly shuffled
00165       std::vector<Double_t>               fEventWeights;        // original weights of the events - follows fTrainingEvents
00166       UInt_t                              fNTreeSample;         // number of events in sub sample = frac*neve
00167 
00168       Double_t                            fNEveEffTrain;    // reweighted number of events = sum(wi)
00169       std::vector< const TMVA::DecisionTree *>  fForest;    // the input forest of decision trees
00170       RuleEnsemble                        fRuleEnsemble;    // the ensemble of rules
00171       RuleFitParams                       fRuleFitParams;   // fit rule parameters
00172       const MethodRuleFit                *fMethodRuleFit;   // pointer the method which initialized this RuleFit instance
00173       const MethodBase                   *fMethodBase;      // pointer the method base which initialized this RuleFit instance
00174       Bool_t                              fVisHistsUseImp;  // if true, use importance as weight; else coef in vis hists
00175 
00176       mutable MsgLogger*                  fLogger;   // message logger
00177       MsgLogger& Log() const { return *fLogger; }    
00178 
00179       static const Int_t randSEED = 0; // set to 1 for debugging purposes or to zero for random seeds
00180 
00181       ClassDef(RuleFit,0)  // Calculations for Friedman's RuleFit method
00182    };
00183 }
00184 
00185 #endif

Generated on Tue Jul 5 14:27:38 2011 for ROOT_528-00b_version by  doxygen 1.5.1