00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029 #ifndef ROOT_TMVA_RuleEnsemble
00030 #define ROOT_TMVA_RuleEnsemble
00031
00032 #if ROOT_VERSION_CODE >= 364802
00033 #ifndef ROOT_TMathBase
00034 #include "TMathBase.h"
00035 #endif
00036 #else
00037 #ifndef ROOT_TMath
00038 #include "TMath.h"
00039 #endif
00040 #endif
00041
00042 #ifndef ROOT_TMVA_DecisionTree
00043 #include "TMVA/DecisionTree.h"
00044 #endif
00045 #ifndef ROOT_TMVA_Event
00046 #include "TMVA/Event.h"
00047 #endif
00048 #ifndef ROOT_TMVA_Rule
00049 #include "TMVA/Rule.h"
00050 #endif
00051 #ifndef ROOT_TMVA_Types
00052 #include "TMVA/Types.h"
00053 #endif
00054
00055 class TH1F;
00056
00057 namespace TMVA {
00058
// Forward declarations of types that this header only uses through
// pointers or references.
// NOTE(review): TBits is declared here inside namespace TMVA and is not used
// anywhere in this header - confirm whether the global TBits was intended.
class TBits;
class MethodBase;
class RuleFit;
class MethodRuleFit;
class RuleEnsemble;
class MsgLogger;

// Stream-insertion operator for a RuleEnsemble (defined out of line).
// The stream type is spelled std::ostream so the header does not depend on a
// using-directive pulled in by some other include.
std::ostream& operator<<( std::ostream& os, const RuleEnsemble& rules );
00067
00068 class RuleEnsemble {
00069
00070
00071 friend ostream& operator<< ( ostream& os, const RuleEnsemble& rules );
00072
00073 public:
00074
00075 enum ELearningModel { kFull=0, kRules=1, kLinear=2 };
00076
00077
00078 RuleEnsemble( RuleFit* rf );
00079
00080
00081 RuleEnsemble( const RuleEnsemble& other );
00082
00083
00084 RuleEnsemble();
00085
00086
00087 virtual ~RuleEnsemble();
00088
00089
00090 void Initialize( const RuleFit* rf );
00091
00092
00093 void SetMsgType( EMsgType t );
00094
00095
00096 void MakeModel();
00097
00098
00099 void MakeRules( const std::vector< const TMVA::DecisionTree *>& forest );
00100
00101
00102 void MakeLinearTerms();
00103
00104
00105 void SetModelLinear() { fLearningModel = kLinear; }
00106
00107
00108 void SetModelRules() { fLearningModel = kRules; }
00109
00110
00111 void SetModelFull() { fLearningModel = kFull; }
00112
00113
00114 void SetRules( const std::vector< TMVA::Rule *> & rules );
00115
00116
00117 void SetRuleFit( const RuleFit *rf ) { fRuleFit = rf; }
00118
00119
00120 void SetCoefficients( const std::vector< Double_t >& v );
00121 void SetCoefficient( UInt_t i, Double_t v ) { if (i<fRules.size()) fRules[i]->SetCoefficient(v); }
00122
00123 void SetOffset(Double_t v=0.0) { fOffset=v; }
00124 void AddOffset(Double_t v) { fOffset+=v; }
00125 void SetLinCoefficients( const std::vector< Double_t >& v ) { fLinCoefficients = v; }
00126 void SetLinCoefficient( UInt_t i, Double_t v ) { fLinCoefficients[i] = v; }
00127 void SetLinDM( const std::vector<Double_t> & xmin ) { fLinDM = xmin; }
00128 void SetLinDP( const std::vector<Double_t> & xmax ) { fLinDP = xmax; }
00129 void SetLinNorm( const std::vector<Double_t> & norm ) { fLinNorm = norm; }
00130
00131 Double_t CalcLinNorm( Double_t stdev ) { return ( stdev>0 ? fAverageRuleSigma/stdev : 1.0 ); }
00132
00133
00134 void ClearCoefficients( Double_t val=0 ) { for (UInt_t i=0; i<fRules.size(); i++) fRules[i]->SetCoefficient(val); }
00135 void ClearLinCoefficients( Double_t val=0 ) { for (UInt_t i=0; i<fLinCoefficients.size(); i++) fLinCoefficients[i]=val; }
00136 void ClearLinNorm( Double_t val=1.0 ) { for (UInt_t i=0; i<fLinNorm.size(); i++) fLinNorm[i]=val; }
00137
00138
00139 void SetRuleMinDist(Double_t d) { fRuleMinDist = d; }
00140
00141
00142 void SetImportanceCut(Double_t minimp=0) { fImportanceCut=minimp; }
00143
00144
00145 void SetLinQuantile(Double_t q) { fLinQuantile=q; }
00146
00147
00148 void SetAverageRuleSigma(Double_t v) { if (v>0.5) v=0.5; fAverageRuleSigma = v; fAverageSupport = 0.5*(1.0+TMath::Sqrt(1.0-4.0*v*v)); }
00149
00150
00151 Int_t CalcNRules( const TMVA::DecisionTree* dtree );
00152
00153 void FindNEndNodes( const TMVA::Node* node, Int_t& nendnodes );
00154
00155
00156 void SetEvent( const Event & e ) { fEvent = &e; fEventCacheOK = kFALSE; }
00157
00158
00159 void UpdateEventVal();
00160
00161
00162 void MakeRuleMap(const std::vector<TMVA::Event *> *events=0, UInt_t ifirst=0, UInt_t ilast=0);
00163
00164
00165 void ClearRuleMap() { fRuleMap.clear(); fRuleMapEvents=0; }
00166
00167
00168
00169 Double_t EvalEvent() const;
00170 Double_t EvalEvent( const Event & e );
00171
00172
00173 Double_t EvalEvent( Double_t ofs,
00174 const std::vector<Double_t> & coefs,
00175 const std::vector<Double_t> & lincoefs) const;
00176 Double_t EvalEvent( const Event & e,
00177 Double_t ofs,
00178 const std::vector<Double_t> & coefs,
00179 const std::vector<Double_t> & lincoefs);
00180
00181
00182
00183 Double_t EvalEvent( UInt_t evtidx ) const;
00184 Double_t EvalEvent( UInt_t evtidx,
00185 Double_t ofs,
00186 const std::vector<Double_t> & coefs,
00187 const std::vector<Double_t> & lincoefs) const;
00188
00189
00190
00191 Double_t EvalLinEvent() const;
00192 Double_t EvalLinEvent( const std::vector<Double_t> & coefs ) const;
00193 Double_t EvalLinEvent( const Event &e );
00194 Double_t EvalLinEvent( const Event &e, UInt_t vind );
00195 Double_t EvalLinEvent( const Event &e, const std::vector<Double_t> & coefs );
00196
00197
00198 Double_t EvalLinEvent( UInt_t evtidx ) const;
00199 Double_t EvalLinEvent( UInt_t evtidx, const std::vector<Double_t> & coefs ) const;
00200 Double_t EvalLinEvent( UInt_t evtidx, UInt_t vind ) const;
00201 Double_t EvalLinEvent( UInt_t evtidx, UInt_t vind, Double_t coefs ) const;
00202
00203
00204 Double_t EvalLinEventRaw( UInt_t vind, const Event &e, Bool_t norm ) const;
00205 Double_t EvalLinEventRaw( UInt_t vind, UInt_t evtidx, Bool_t norm ) const;
00206
00207
00208 Double_t PdfLinear( Double_t & nsig, Double_t & ntot ) const;
00209
00210
00211 Double_t PdfRule( Double_t & nsig, Double_t & ntot ) const;
00212
00213
00214 Double_t FStar() const;
00215 Double_t FStar(const TMVA::Event & e );
00216
00217
00218 void SetImportanceRef(Double_t impref);
00219
00220
00221 void CalcRuleSupport();
00222
00223
00224 void CalcImportance();
00225
00226
00227 Double_t CalcRuleImportance();
00228
00229
00230 Double_t CalcLinImportance();
00231
00232
00233 void CalcVarImportance();
00234
00235
00236 void CleanupRules();
00237
00238
00239 void CleanupLinear();
00240
00241
00242 void RemoveSimilarRules();
00243
00244
00245 void RuleStatistics();
00246
00247
00248 void RuleResponseStats();
00249
00250
00251 void operator=( const RuleEnsemble& other ) { Copy( other ); }
00252
00253
00254 Double_t CoefficientRadius();
00255
00256
00257 void GetCoefficients( std::vector< Double_t >& v );
00258
00259
00260 const MethodRuleFit* GetMethodRuleFit() const;
00261 const MethodBase* GetMethodBase() const;
00262 const RuleFit* GetRuleFit() const { return fRuleFit; }
00263
00264 const std::vector<TMVA::Event *>* GetTrainingEvents() const;
00265 const Event* GetTrainingEvent(UInt_t i) const;
00266 const Event* GetEvent() const { return fEvent; }
00267
00268 Bool_t DoLinear() const { return (fLearningModel==kFull) || (fLearningModel==kLinear); }
00269 Bool_t DoRules() const { return (fLearningModel==kFull) || (fLearningModel==kRules); }
00270 Bool_t DoOnlyRules() const { return (fLearningModel==kRules); }
00271 Bool_t DoOnlyLinear() const { return (fLearningModel==kLinear); }
00272 Bool_t DoFull() const { return (fLearningModel==kFull); }
00273 ELearningModel GetLearningModel() const { return fLearningModel; }
00274 Double_t GetImportanceCut() const { return fImportanceCut; }
00275 Double_t GetImportanceRef() const { return fImportanceRef; }
00276 Double_t GetOffset() const { return fOffset; }
00277 UInt_t GetNRules() const { return (DoRules() ? fRules.size():0); }
00278 const std::vector<TMVA::Rule*>& GetRulesConst() const { return fRules; }
00279 std::vector<TMVA::Rule*>& GetRules() { return fRules; }
00280 const std::vector< Double_t >& GetLinCoefficients() const { return fLinCoefficients; }
00281 const std::vector< Double_t >& GetLinNorm() const { return fLinNorm; }
00282 const std::vector< Double_t >& GetLinImportance() const { return fLinImportance; }
00283 const std::vector< Double_t >& GetVarImportance() const { return fVarImportance; }
00284 UInt_t GetNLinear() const { return (DoLinear() ? fLinNorm.size():0); }
00285 Double_t GetLinQuantile() const { return fLinQuantile; }
00286
00287 const Rule *GetRulesConst(int i) const { return fRules[i]; }
00288 Rule *GetRules(int i) { return fRules[i]; }
00289
00290 UInt_t GetRulesNCuts(int i) const { return fRules[i]->GetRuleCut()->GetNcuts(); }
00291 Double_t GetRuleMinDist() const { return fRuleMinDist; }
00292 Double_t GetLinCoefficients(int i) const { return fLinCoefficients[i]; }
00293 Double_t GetLinNorm(int i) const { return fLinNorm[i]; }
00294 Double_t GetLinDM(int i) const { return fLinDM[i]; }
00295 Double_t GetLinDP(int i) const { return fLinDP[i]; }
00296 Double_t GetLinImportance(int i) const { return fLinImportance[i]; }
00297 Double_t GetVarImportance(int i) const { return fVarImportance[i]; }
00298 Double_t GetRulePTag(int i) const { return fRulePTag[i]; }
00299 Double_t GetRulePSS(int i) const { return fRulePSS[i]; }
00300 Double_t GetRulePSB(int i) const { return fRulePSB[i]; }
00301 Double_t GetRulePBS(int i) const { return fRulePBS[i]; }
00302 Double_t GetRulePBB(int i) const { return fRulePBB[i]; }
00303
00304 Bool_t IsLinTermOK(int i) const { return fLinTermOK[i]; }
00305
00306 Double_t GetAverageSupport() const { return fAverageSupport; }
00307 Double_t GetAverageRuleSigma() const { return fAverageRuleSigma; }
00308 Double_t GetEventRuleVal(UInt_t i) const { return (fEventRuleVal[i] ? 1.0:0.0); }
00309 Double_t GetEventLinearVal(UInt_t i) const { return fEventLinearVal[i]; }
00310 Double_t GetEventLinearValNorm(UInt_t i) const { return fEventLinearVal[i]*fLinNorm[i]; }
00311
00312 const std::vector<UInt_t> & GetEventRuleMap(UInt_t evtidx) const { return fRuleMap[evtidx]; }
00313 const TMVA::Event *GetRuleMapEvent(UInt_t evtidx) const { return (*fRuleMapEvents)[evtidx]; }
00314 Bool_t IsRuleMapOK() const { return fRuleMapOK; }
00315
00316
00317 void PrintRuleGen() const;
00318
00319
00320 void Print() const;
00321
00322
00323 void PrintRaw ( ostream& os ) const;
00324 void* AddXMLTo ( void* parent ) const;
00325
00326
00327 void ReadRaw ( istream& istr );
00328 void ReadFromXML( void* wghtnode );
00329
00330
00331 private:
00332
00333
00334 void DeleteRules() { for (UInt_t i=0; i<fRules.size(); i++) delete fRules[i]; fRules.clear(); }
00335
00336
00337 void Copy( RuleEnsemble const& other );
00338
00339
00340 void ResetCoefficients();
00341
00342
00343 void MakeRulesFromTree( const DecisionTree *dtree );
00344
00345
00346 void AddRule( const Node *node );
00347
00348
00349 Rule *MakeTheRule( const Node *node );
00350
00351
00352 ELearningModel fLearningModel;
00353 Double_t fImportanceCut;
00354 Double_t fLinQuantile;
00355 Double_t fOffset;
00356 std::vector< TMVA::Rule* > fRules;
00357 std::vector< Char_t > fLinTermOK;
00358 std::vector< Double_t > fLinDP;
00359 std::vector< Double_t > fLinDM;
00360 std::vector< Double_t > fLinCoefficients;
00361 std::vector< Double_t > fLinNorm;
00362 std::vector< TH1F* > fLinPDFB;
00363 std::vector< TH1F* > fLinPDFS;
00364 std::vector< Double_t > fLinImportance;
00365 std::vector< Double_t > fVarImportance;
00366 Double_t fImportanceRef;
00367 Double_t fAverageSupport;
00368 Double_t fAverageRuleSigma;
00369
00370 std::vector< Double_t > fRuleVarFrac;
00371 std::vector< Double_t > fRulePSS;
00372 std::vector< Double_t > fRulePSB;
00373 std::vector< Double_t > fRulePBS;
00374 std::vector< Double_t > fRulePBB;
00375 std::vector< Double_t > fRulePTag;
00376 Double_t fRuleFSig;
00377 Double_t fRuleNCave;
00378 Double_t fRuleNCsig;
00379
00380 Double_t fRuleMinDist;
00381 UInt_t fNRulesGenerated;
00382
00383 const Event* fEvent;
00384 Bool_t fEventCacheOK;
00385 std::vector<Char_t> fEventRuleVal;
00386 std::vector<Double_t> fEventLinearVal;
00387
00388 Bool_t fRuleMapOK;
00389 std::vector< std::vector<UInt_t> > fRuleMap;
00390 UInt_t fRuleMapInd0;
00391 UInt_t fRuleMapInd1;
00392 const std::vector<TMVA::Event *> *fRuleMapEvents;
00393
00394 const RuleFit* fRuleFit;
00395
00396 mutable MsgLogger* fLogger;
00397 MsgLogger& Log() const { return *fLogger; }
00398 };
00399 }
00400
00401
00402 inline void TMVA::RuleEnsemble::UpdateEventVal()
00403 {
00404
00405
00406
00407 if (fEventCacheOK) return;
00408
00409 if (DoRules()) {
00410 UInt_t nrules = fRules.size();
00411 fEventRuleVal.resize(nrules,kFALSE);
00412 for (UInt_t r=0; r<nrules; r++) {
00413 fEventRuleVal[r] = fRules[r]->EvalEvent(*fEvent);
00414 }
00415 }
00416 if (DoLinear()) {
00417 UInt_t nlin = fLinTermOK.size();
00418 fEventLinearVal.resize(nlin,0);
00419 for (UInt_t r=0; r<nlin; r++) {
00420 fEventLinearVal[r] = EvalLinEventRaw(r,*fEvent,kFALSE);
00421 }
00422 }
00423 fEventCacheOK = kTRUE;
00424 }
00425
00426
00427 inline Double_t TMVA::RuleEnsemble::EvalEvent() const
00428 {
00429
00430
00431 Int_t nrules = fRules.size();
00432 Double_t rval=fOffset;
00433 Double_t linear=0;
00434
00435
00436
00437
00438 if (DoRules()) {
00439 for ( Int_t i=0; i<nrules; i++ ) {
00440 if (fEventRuleVal[i])
00441 rval += fRules[i]->GetCoefficient();
00442 }
00443 }
00444
00445
00446
00447 if (DoLinear()) linear = EvalLinEvent();
00448 rval +=linear;
00449
00450 return rval;
00451 }
00452
00453
00454 inline Double_t TMVA::RuleEnsemble::EvalEvent( Double_t ofs,
00455 const std::vector<Double_t> & coefs,
00456 const std::vector<Double_t> & lincoefs ) const
00457 {
00458
00459
00460 Int_t nrules = fRules.size();
00461 Double_t rval = ofs;
00462 Double_t linear = 0;
00463
00464
00465
00466 if (DoRules()) {
00467 for ( Int_t i=0; i<nrules; i++ ) {
00468 if (fEventRuleVal[i])
00469 rval += coefs[i];
00470 }
00471 }
00472
00473
00474
00475 if (DoLinear()) linear = EvalLinEvent(lincoefs);
00476 rval +=linear;
00477
00478 return rval;
00479 }
00480
00481
00482 inline Double_t TMVA::RuleEnsemble::EvalEvent(const TMVA::Event & e)
00483 {
00484
00485 SetEvent(e);
00486 UpdateEventVal();
00487 return EvalEvent();
00488 }
00489
00490
00491 inline Double_t TMVA::RuleEnsemble::EvalEvent(const TMVA::Event & e,
00492 Double_t ofs,
00493 const std::vector<Double_t> & coefs,
00494 const std::vector<Double_t> & lincoefs )
00495 {
00496
00497 SetEvent(e);
00498 UpdateEventVal();
00499 return EvalEvent(ofs,coefs,lincoefs);
00500 }
00501
00502
00503 inline Double_t TMVA::RuleEnsemble::EvalEvent(UInt_t evtidx) const
00504 {
00505
00506 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
00507
00508 Double_t rval=fOffset;
00509 if (DoRules()) {
00510 UInt_t nrules = fRuleMap[evtidx].size();
00511 UInt_t rind;
00512 for (UInt_t ir = 0; ir<nrules; ir++) {
00513 rind = fRuleMap[evtidx][ir];
00514 rval += fRules[rind]->GetCoefficient();
00515 }
00516 }
00517 if (DoLinear()) {
00518 UInt_t nlin = fLinTermOK.size();
00519 for (UInt_t r=0; r<nlin; r++) {
00520 if (fLinTermOK[r]) {
00521 rval += fLinCoefficients[r] * EvalLinEventRaw(r,*(*fRuleMapEvents)[evtidx],kTRUE);
00522 }
00523 }
00524 }
00525 return rval;
00526 }
00527
00528
00529 inline Double_t TMVA::RuleEnsemble::EvalEvent(UInt_t evtidx,
00530 Double_t ofs,
00531 const std::vector<Double_t> & coefs,
00532 const std::vector<Double_t> & lincoefs ) const
00533 {
00534
00535
00536 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
00537 Double_t rval=ofs;
00538 if (DoRules()) {
00539 UInt_t nrules = fRuleMap[evtidx].size();
00540 UInt_t rind;
00541 for (UInt_t ir = 0; ir<nrules; ir++) {
00542 rind = fRuleMap[evtidx][ir];
00543 rval += coefs[rind];
00544 }
00545 }
00546 if (DoLinear()) {
00547 rval += EvalLinEvent( evtidx, lincoefs );
00548 }
00549 return rval;
00550 }
00551
00552
00553 inline Double_t TMVA::RuleEnsemble::EvalLinEventRaw( UInt_t vind, const TMVA::Event & e, Bool_t norm) const
00554 {
00555
00556
00557 Double_t val = e.GetValue(vind);
00558 Double_t rval = TMath::Min( fLinDP[vind], TMath::Max( fLinDM[vind], val ) );
00559 if (norm) rval *= fLinNorm[vind];
00560 return rval;
00561 }
00562
00563
00564 inline Double_t TMVA::RuleEnsemble::EvalLinEventRaw( UInt_t vind, UInt_t evtidx, Bool_t norm) const
00565 {
00566
00567
00568 Double_t val = (*fRuleMapEvents)[evtidx]->GetValue(vind);
00569 Double_t rval = TMath::Min( fLinDP[vind], TMath::Max( fLinDM[vind], val ) );
00570 if (norm) rval *= fLinNorm[vind];
00571 return rval;
00572 }
00573
00574
00575 inline Double_t TMVA::RuleEnsemble::EvalLinEvent() const
00576 {
00577
00578
00579 Double_t rval=0;
00580 for (UInt_t v=0; v<fLinTermOK.size(); v++) {
00581 if (fLinTermOK[v])
00582 rval += fLinCoefficients[v]*fEventLinearVal[v]*fLinNorm[v];
00583 }
00584 return rval;
00585 }
00586
00587
00588 inline Double_t TMVA::RuleEnsemble::EvalLinEvent(const std::vector<Double_t> & coefs) const
00589 {
00590
00591
00592 Double_t rval=0;
00593 for (UInt_t v=0; v<fLinTermOK.size(); v++) {
00594 if (fLinTermOK[v])
00595 rval += coefs[v]*fEventLinearVal[v]*fLinNorm[v];
00596 }
00597 return rval;
00598 }
00599
00600
00601 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( const TMVA::Event& e )
00602 {
00603
00604
00605 SetEvent(e);
00606 UpdateEventVal();
00607 return EvalLinEvent();
00608 }
00609
00610
00611 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( const TMVA::Event& e, UInt_t vind )
00612 {
00613
00614
00615 SetEvent(e);
00616 UpdateEventVal();
00617 return GetEventLinearValNorm(vind);
00618 }
00619
00620
00621 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( const TMVA::Event& e, const std::vector<Double_t> & coefs )
00622 {
00623
00624
00625 SetEvent(e);
00626 UpdateEventVal();
00627 return EvalLinEvent(coefs);
00628 }
00629
00630
00631 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx, const std::vector<Double_t> & coefs ) const
00632 {
00633
00634 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
00635 Double_t rval=0;
00636 UInt_t nlin = fLinTermOK.size();
00637 for (UInt_t r=0; r<nlin; r++) {
00638 if (fLinTermOK[r]) {
00639 rval += coefs[r] * EvalLinEventRaw(r,*(*fRuleMapEvents)[evtidx],kTRUE);
00640 }
00641 }
00642 return rval;
00643 }
00644
00645
00646 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx ) const
00647 {
00648
00649 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
00650 Double_t rval=0;
00651 UInt_t nlin = fLinTermOK.size();
00652 for (UInt_t r=0; r<nlin; r++) {
00653 if (fLinTermOK[r]) {
00654 rval += fLinCoefficients[r] * EvalLinEventRaw(r,*(*fRuleMapEvents)[evtidx],kTRUE);
00655 }
00656 }
00657 return rval;
00658 }
00659
00660
00661 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx, UInt_t vind ) const
00662 {
00663
00664 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
00665 Double_t rval;
00666 rval = fLinCoefficients[vind] * EvalLinEventRaw(vind,*(*fRuleMapEvents)[evtidx],kTRUE);
00667 return rval;
00668 }
00669
00670
00671 inline Double_t TMVA::RuleEnsemble::EvalLinEvent( UInt_t evtidx, UInt_t vind, Double_t coefs ) const
00672 {
00673
00674 if ((evtidx<fRuleMapInd0) || (evtidx>fRuleMapInd1)) return 0;
00675 Double_t rval;
00676 rval = coefs * EvalLinEventRaw(vind,*(*fRuleMapEvents)[evtidx],kTRUE);
00677 return rval;
00678 }
00679
00680 #endif