RuleEnsemble.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: RuleEnsemble.cxx 36261 2010-10-10 21:12:08Z stelzer $
00002 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : RuleEnsemble                                                          *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      A class generating an ensemble of rules                                   *
00012  *      Input:  a forest of decision trees                                        *
00013  *      Output: an ensemble of rules                                              *
00014  *                                                                                *
00015  * Authors (alphabetical):                                                        *
00016  *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
00017  *      Helge Voss         <Helge.Voss@cern.ch>         - MPI-KP Heidelberg, GER  *
00018  *                                                                                *
00019  * Copyright (c) 2005:                                                            *
00020  *      CERN, Switzerland                                                         *
00021  *      Iowa State U.                                                             *
00022  *      MPI-K Heidelberg, Germany                                                 *
00023  *                                                                                *
00024  * Redistribution and use in source and binary forms, with or without             *
00025  * modification, are permitted according to the terms listed in LICENSE           *
00026  * (http://tmva.sourceforge.net/LICENSE)                                          *
00027  **********************************************************************************/
00028 
00029 #include <algorithm>
00030 #include <list>
00031 #include <cstdlib>
00032 #include <iomanip>
00033 
00034 #include "TRandom3.h"
00035 #include "TH1F.h"
00036 #include "TMVA/RuleEnsemble.h"
00037 #include "TMVA/RuleFit.h"
00038 #include "TMVA/MethodRuleFit.h"
00039 #include "TMVA/Tools.h"
00040 
00041 //_______________________________________________________________________
00042 TMVA::RuleEnsemble::RuleEnsemble( RuleFit *rf )
00043    : fLearningModel   ( kFull )
00044    , fImportanceCut   ( 0 )
00045    , fLinQuantile     ( 0.025 ) // default quantile for killing outliers in linear terms
00046    , fOffset          ( 0 )
00047    , fAverageSupport  ( 0.8 )
00048    , fAverageRuleSigma( 0.4 )  // default value - used if only linear model is chosen
00049    , fRuleFSig        ( 0 )
00050    , fRuleNCave       ( 0 )
00051    , fRuleNCsig       ( 0 )
00052    , fRuleMinDist     ( 1e-3 ) // closest allowed 'distance' between two rules
00053    , fNRulesGenerated ( 0 )
00054    , fEvent           ( 0 )
00055    , fEventCacheOK    ( true )
00056    , fRuleMapOK       ( true )
00057    , fRuleMapInd0     ( 0 )
00058    , fRuleMapInd1     ( 0 )
00059    , fRuleMapEvents   ( 0 )
00060    , fLogger( new MsgLogger("RuleFit") )
00061 {
00062    // constructor
00063    Initialize( rf );
00064 }
00065 
00066 //_______________________________________________________________________
00067 TMVA::RuleEnsemble::RuleEnsemble( const RuleEnsemble& other )
00068    : fAverageSupport   ( 1 )
00069    , fEvent(0)
00070    , fRuleMapEvents(0)
00071    , fRuleFit(0)
00072    , fLogger( new MsgLogger("RuleFit") )
00073 {
00074    // copy constructor
00075    Copy( other );
00076 }
00077 
00078 //_______________________________________________________________________
00079 TMVA::RuleEnsemble::RuleEnsemble()
00080    : fLearningModel     ( kFull )
00081    , fImportanceCut   ( 0 )
00082    , fLinQuantile     ( 0.025 ) // default quantile for killing outliers in linear terms
00083    , fOffset          ( 0 )
00084    , fImportanceRef   ( 1.0 )
00085    , fAverageSupport  ( 0.8 )
00086    , fAverageRuleSigma( 0.4 )  // default value - used if only linear model is chosen
00087    , fRuleFSig        ( 0 )
00088    , fRuleNCave       ( 0 )
00089    , fRuleNCsig       ( 0 )
00090    , fRuleMinDist     ( 1e-3 ) // closest allowed 'distance' between two rules
00091    , fNRulesGenerated ( 0 )
00092    , fEvent           ( 0 )
00093    , fEventCacheOK    ( true )
00094    , fRuleMapOK       ( true )
00095    , fRuleMapInd0     ( 0 )
00096    , fRuleMapInd1     ( 0 )
00097    , fRuleMapEvents   ( 0 )
00098    , fRuleFit         ( 0 )
00099    , fLogger( new MsgLogger("RuleFit") )
00100 {
00101    // constructor
00102 }
00103 
00104 //_______________________________________________________________________
00105 TMVA::RuleEnsemble::~RuleEnsemble()
00106 {
00107    // destructor
00108    for ( std::vector<Rule *>::iterator itrRule = fRules.begin(); itrRule != fRules.end(); itrRule++ ) {
00109       delete *itrRule;
00110    }
00111    // NOTE: Should not delete the histos fLinPDFB/S since they are delete elsewhere
00112    delete fLogger;
00113 }
00114 
00115 //_______________________________________________________________________
00116 void TMVA::RuleEnsemble::Initialize( const RuleFit *rf )
00117 {
00118    // Initializes all member variables with default values
00119 
00120    SetAverageRuleSigma(0.4); // default value - used if only linear model is chosen
00121    fRuleFit = rf;
00122    UInt_t nvars =  GetMethodBase()->GetNvar();
00123    fVarImportance.clear();
00124    fLinPDFB.clear();
00125    fLinPDFS.clear();
00126    //
00127    fVarImportance.resize( nvars,0.0 );
00128    fLinPDFB.resize( nvars,0 );
00129    fLinPDFS.resize( nvars,0 );
00130    fImportanceRef = 1.0;
00131    for (UInt_t i=0; i<nvars; i++) { // a priori all linear terms are equally valid
00132       fLinTermOK.push_back(kTRUE);
00133    }
00134 }
00135 
00136 //_______________________________________________________________________
00137 void TMVA::RuleEnsemble::SetMsgType( EMsgType t ) {
00138    fLogger->SetMinType(t);
00139 }
00140 
00141 
00142 //_______________________________________________________________________
00143 const TMVA::MethodRuleFit*  TMVA::RuleEnsemble::GetMethodRuleFit() const
00144 {
00145    //
00146    // Get a pointer to the original MethodRuleFit.
00147    //
00148    return ( fRuleFit==0 ? 0:fRuleFit->GetMethodRuleFit());
00149 }
00150 
00151 //_______________________________________________________________________
00152 const TMVA::MethodBase*  TMVA::RuleEnsemble::GetMethodBase() const
00153 {
00154    //
00155    // Get a pointer to the original MethodRuleFit.
00156    //
00157    return ( fRuleFit==0 ? 0:fRuleFit->GetMethodBase());
00158 }
00159 
00160 //_______________________________________________________________________
00161 void TMVA::RuleEnsemble::MakeModel()
00162 {
00163    // create model
00164    MakeRules( fRuleFit->GetForest() );
00165    
00166    MakeLinearTerms();
00167 
00168    MakeRuleMap();
00169 
00170    CalcRuleSupport();
00171 
00172    RuleStatistics();
00173 
00174    PrintRuleGen();
00175 }
00176 
00177 //_______________________________________________________________________
00178 Double_t TMVA::RuleEnsemble::CoefficientRadius()
00179 {
00180    //
00181    // Calculates sqrt(Sum(a_i^2)), i=1..N (NOTE do not include a0)
00182    //
00183    Int_t ncoeffs = fRules.size();
00184    if (ncoeffs<1) return 0;
00185    //
00186    Double_t sum2=0;
00187    Double_t val;
00188    for (Int_t i=0; i<ncoeffs; i++) {
00189       val = fRules[i]->GetCoefficient();
00190       sum2 += val*val;
00191    }
00192    return sum2;
00193 }
00194 
00195 //_______________________________________________________________________
00196 void TMVA::RuleEnsemble::ResetCoefficients()
00197 {
00198    // reset all rule coefficients
00199 
00200    fOffset = 0.0;
00201    UInt_t nrules = fRules.size();
00202    for (UInt_t i=0; i<nrules; i++) {
00203       fRules[i]->SetCoefficient(0.0);
00204    }
00205 }
00206 
00207 //_______________________________________________________________________
00208 void TMVA::RuleEnsemble::SetCoefficients( const std::vector< Double_t > & v )
00209 {
00210    // set all rule coefficients
00211 
00212    UInt_t nrules = fRules.size();
00213    if (v.size()!=nrules) {
00214       Log() << kFATAL << "<SetCoefficients> - BUG TRAP - input vector worng size! It is = " << v.size()
00215             << " when it should be = " << nrules << Endl;
00216    }
00217    for (UInt_t i=0; i<nrules; i++) {
00218       fRules[i]->SetCoefficient(v[i]);
00219    }
00220 }
00221 
00222 //_______________________________________________________________________
00223 void TMVA::RuleEnsemble::GetCoefficients( std::vector< Double_t > & v )
00224 {
00225    // Retrieve all rule coefficients
00226 
00227    UInt_t nrules = fRules.size();
00228    v.resize(nrules);
00229    if (nrules==0) return;
00230    //
00231    for (UInt_t i=0; i<nrules; i++) {
00232       v[i] = (fRules[i]->GetCoefficient());
00233    }
00234 }
00235 
00236 //_______________________________________________________________________
00237 const std::vector<TMVA::Event*>* TMVA::RuleEnsemble::GetTrainingEvents()  const
00238 { 
00239    // get list of training events from the rule fitter
00240 
00241    return &(fRuleFit->GetTrainingEvents());
00242 }
00243 
00244 //_______________________________________________________________________
00245 const TMVA::Event * TMVA::RuleEnsemble::GetTrainingEvent(UInt_t i) const
00246 {
00247    // get the training event from the rule fitter
00248    return fRuleFit->GetTrainingEvent(i);
00249 }
00250 
00251 //_______________________________________________________________________
00252 void TMVA::RuleEnsemble::RemoveSimilarRules()
00253 {
00254    // remove rules that behave similar 
00255    
00256    Log() << kVERBOSE << "Removing similar rules; distance = " << fRuleMinDist << Endl;
00257 
00258    UInt_t nrulesIn = fRules.size();
00259    TMVA::Rule *first, *second;
00260    std::vector< Char_t > removeMe( nrulesIn,false );  // <--- stores boolean
00261 
00262    Int_t nrem = 0;
00263    Int_t remind=-1;
00264    Double_t r;
00265 
00266    for (UInt_t i=0; i<nrulesIn; i++) {
00267       if (!removeMe[i]) {
00268          first = fRules[i];
00269          for (UInt_t k=i+1; k<nrulesIn; k++) {
00270             if (!removeMe[k]) {
00271                second = fRules[k];
00272                Bool_t equal = first->Equal(*second,kTRUE,fRuleMinDist);
00273                if (equal) {
00274                   r = gRandom->Rndm();
00275                   remind = (r>0.5 ? k:i); // randomly select rule
00276                } 
00277                else {
00278                   remind = -1;
00279                }
00280 
00281                if (remind>-1) {
00282                   if (!removeMe[remind]) {
00283                      removeMe[remind] = true;
00284                      nrem++;
00285                   }
00286                }
00287             }
00288          }
00289       }
00290    }
00291    UInt_t ind = 0;
00292    Rule *theRule;
00293    for (UInt_t i=0; i<nrulesIn; i++) {
00294       if (removeMe[i]) {
00295          theRule = fRules[ind];
00296 #if _MSC_VER >= 1400
00297          fRules.erase( std::vector<Rule *>::iterator(&fRules[ind], &fRules) );
00298 #else
00299          fRules.erase( std::vector<Rule *>::iterator(&fRules[ind]) );
00300 #endif
00301          delete theRule;
00302          ind--;
00303       } 
00304       ind++;
00305    }
00306    UInt_t nrulesOut = fRules.size();
00307    Log() << kVERBOSE << "Removed " << nrulesIn - nrulesOut << " out of " << nrulesIn << " rules" << Endl;
00308 }
00309 
00310 //_______________________________________________________________________
00311 void TMVA::RuleEnsemble::CleanupRules()
00312 {
00313    // cleanup rules
00314 
00315    UInt_t nrules   = fRules.size();
00316    if (nrules==0) return;
00317    Log() << kVERBOSE << "Removing rules with relative importance < " << fImportanceCut << Endl;
00318    if (fImportanceCut<=0) return;
00319    //
00320    // Mark rules to be removed
00321    //
00322    Rule *therule;
00323    Int_t ind=0;
00324    for (UInt_t i=0; i<nrules; i++) {
00325       if (fRules[ind]->GetRelImportance()<fImportanceCut) {
00326          therule = fRules[ind];
00327 #if _MSC_VER >= 1400
00328          fRules.erase( std::vector<Rule *>::iterator(&fRules[ind], &fRules) );
00329 #else
00330          fRules.erase( std::vector<Rule *>::iterator(&fRules[ind]) );
00331 #endif
00332          delete therule;
00333          ind--;
00334       } 
00335       ind++;
00336    }
00337    Log() << kINFO << "Removed " << nrules-ind << " out of a total of " << nrules
00338          << " rules with importance < " << fImportanceCut << Endl;
00339 }
00340 
00341 //_______________________________________________________________________
00342 void TMVA::RuleEnsemble::CleanupLinear()
00343 {
00344    // cleanup linear model
00345 
00346    UInt_t nlin = fLinNorm.size();
00347    if (nlin==0) return;
00348    Log() << kVERBOSE << "Removing linear terms with relative importance < " << fImportanceCut << Endl;
00349    //
00350    fLinTermOK.clear();
00351    for (UInt_t i=0; i<nlin; i++) {
00352       fLinTermOK.push_back( (fLinImportance[i]/fImportanceRef > fImportanceCut) );
00353    }
00354 }
00355 
00356 //_______________________________________________________________________
00357 void TMVA::RuleEnsemble::CalcRuleSupport()
00358 {
00359    // calculate the support for all rules
00360    Log() << kVERBOSE << "Evaluating Rule support" << Endl;
00361    Double_t s,t,stot,ttot,ssb;
00362    Double_t ssig, sbkg, ssum;
00363    Int_t indrule=0;
00364    stot = 0;
00365    ttot = 0;
00366    // reset to default values
00367    SetAverageRuleSigma(0.4);
00368    const std::vector<Event *> *events = GetTrainingEvents();
00369    Double_t nrules = static_cast<Double_t>(fRules.size());
00370    Double_t ew;
00371    //
00372    if ((nrules>0) && (events->size()>0)) {
00373       for ( std::vector< Rule * >::iterator itrRule=fRules.begin(); itrRule!=fRules.end(); itrRule++ ) {
00374          s=0.0;
00375          ssig=0.0;
00376          sbkg=0.0;
00377          for ( std::vector<Event * >::const_iterator itrEvent=events->begin(); itrEvent!=events->end(); itrEvent++ ) {
00378             if ((*itrRule)->EvalEvent( *(*itrEvent) )) {
00379                ew = (*itrEvent)->GetWeight();
00380                s += ew;
00381                if (GetMethodRuleFit()->DataInfo().IsSignal(*itrEvent)) ssig += ew;
00382                else                         sbkg += ew;
00383             }
00384          }
00385          //
00386          s = s/fRuleFit->GetNEveEff();
00387          t = s*(1.0-s);
00388          t = (t<0 ? 0:sqrt(t));
00389          stot += s;
00390          ttot += t;
00391          ssum = ssig+sbkg;
00392          ssb = (ssum>0 ? Double_t(ssig)/Double_t(ssig+sbkg) : 0.0 );
00393          (*itrRule)->SetSupport(s);
00394          (*itrRule)->SetNorm(t);
00395          (*itrRule)->SetSSB( ssb );
00396          (*itrRule)->SetSSBNeve(Double_t(ssig+sbkg));
00397          indrule++;
00398       }
00399       fAverageSupport   = stot/nrules;
00400       fAverageRuleSigma = TMath::Sqrt(fAverageSupport*(1.0-fAverageSupport));
00401       Log() << kVERBOSE << "Standard deviation of support = " << fAverageRuleSigma << Endl;
00402       Log() << kVERBOSE << "Average rule support          = " << fAverageSupport   << Endl;
00403    }
00404 }
00405 
00406 //_______________________________________________________________________
00407 void TMVA::RuleEnsemble::CalcImportance()
00408 {
00409    // calculate the importance of each rule
00410 
00411    Double_t maxRuleImp = CalcRuleImportance();
00412    Double_t maxLinImp  = CalcLinImportance();
00413    Double_t maxImp = (maxRuleImp>maxLinImp ? maxRuleImp : maxLinImp);
00414    SetImportanceRef( maxImp );
00415 }
00416 
00417 //_______________________________________________________________________
00418 void TMVA::RuleEnsemble::SetImportanceRef(Double_t impref)
00419 {
00420    // set reference importance
00421    for ( UInt_t i=0; i<fRules.size(); i++ ) {
00422       fRules[i]->SetImportanceRef(impref);
00423    }
00424    fImportanceRef = impref;
00425 }
00426 //_______________________________________________________________________
00427 Double_t TMVA::RuleEnsemble::CalcRuleImportance()
00428 {
00429    // calculate importance of each rule
00430 
00431    Double_t maxImp=-1.0;
00432    Double_t imp;
00433    Int_t nrules = fRules.size();
00434    for ( int i=0; i<nrules; i++ ) {
00435       fRules[i]->CalcImportance();
00436       imp = fRules[i]->GetImportance();
00437       if (imp>maxImp) maxImp = imp;
00438    }
00439    for ( Int_t i=0; i<nrules; i++ ) {
00440       fRules[i]->SetImportanceRef(maxImp);
00441    }
00442 
00443    return maxImp;
00444 }
00445 
00446 //_______________________________________________________________________
00447 Double_t TMVA::RuleEnsemble::CalcLinImportance()
00448 {
00449    // calculate the linear importance for each rule
00450 
00451    Double_t maxImp=-1.0;
00452    UInt_t nvars = fLinCoefficients.size();
00453    fLinImportance.resize(nvars,0.0);
00454    if (!DoLinear()) return maxImp;
00455    //
00456    // The linear importance is:
00457    // I = |b_x|*sigma(x)
00458    // Note that the coefficients in fLinCoefficients are for the normalized x
00459    // => b'_x * x' = b'_x * sigma(r)*x/sigma(x)
00460    // => b_x = b'_x*sigma(r)/sigma(x)
00461    // => I = |b'_x|*sigma(r)
00462    //
00463    Double_t imp;
00464    for ( UInt_t i=0; i<nvars; i++ ) {
00465       imp = fAverageRuleSigma*TMath::Abs(fLinCoefficients[i]);
00466       fLinImportance[i] = imp;
00467       if (imp>maxImp) maxImp = imp;
00468    }
00469    return maxImp;
00470 }
00471 
00472 //_______________________________________________________________________
00473 void TMVA::RuleEnsemble::CalcVarImportance()
00474 {
00475    //
00476    // Calculates variable importance using eq (35) in RuleFit paper by Friedman et.al
00477    //
00478    Log() << kVERBOSE << "Compute variable importance" << Endl;
00479    Double_t rimp;
00480    UInt_t nrules = fRules.size();
00481    if (GetMethodBase()==0) Log() << kFATAL << "RuleEnsemble::CalcVarImportance() - should not be here!" << Endl;
00482    UInt_t nvars  = GetMethodBase()->GetNvar();
00483    UInt_t nvarsUsed;
00484    Double_t rimpN;
00485    fVarImportance.resize(nvars,0);
00486    // rules
00487    if (DoRules()) {
00488       for ( UInt_t ind=0; ind<nrules; ind++ ) {
00489          rimp = fRules[ind]->GetImportance();
00490          nvarsUsed = fRules[ind]->GetNumVarsUsed();
00491          if (nvarsUsed<1)
00492             Log() << kFATAL << "<CalcVarImportance> Variables for importance calc!!!??? A BUG!" << Endl;
00493          rimpN = (nvarsUsed > 0 ? rimp/nvarsUsed:0.0);
00494          for ( UInt_t iv=0; iv<nvars; iv++ ) {
00495             if (fRules[ind]->ContainsVariable(iv)) {
00496                fVarImportance[iv] += rimpN;
00497             }
00498          }
00499       }
00500    }
00501    // linear terms
00502    if (DoLinear()) {
00503       for ( UInt_t iv=0; iv<fLinTermOK.size(); iv++ ) {
00504          if (fLinTermOK[iv]) fVarImportance[iv] += fLinImportance[iv];
00505       }
00506    }
00507    //
00508    // Make variable importance relative the strongest variable
00509    //
00510    Double_t maximp = 0.0;
00511    for ( UInt_t iv=0; iv<nvars; iv++ ) {
00512       if ( fVarImportance[iv] > maximp ) maximp = fVarImportance[iv];
00513    }
00514    if (maximp>0) {
00515       for ( UInt_t iv=0; iv<nvars; iv++ ) {
00516          fVarImportance[iv] *= 1.0/maximp;
00517       }
00518    }
00519 }
00520 
00521 //_______________________________________________________________________
00522 void TMVA::RuleEnsemble::SetRules( const std::vector<Rule *> & rules )
00523 {
00524    // set rules
00525    //
00526    // first clear all
00527    DeleteRules();
00528    //
00529    fRules.resize(rules.size());
00530    for (UInt_t i=0; i<fRules.size(); i++) {
00531       fRules[i] = rules[i];
00532    }
00533    fEventCacheOK = kFALSE;
00534 }
00535 
00536 //_______________________________________________________________________
00537 void TMVA::RuleEnsemble::MakeRules( const std::vector< const DecisionTree *> & forest )
00538 {
00539    //
00540    // Makes rules from the given decision tree.
00541    // First node in all rules is ALWAYS the root node.
00542    //
00543    fRules.clear();
00544    if (!DoRules()) return;
00545    //
00546    Int_t nrulesCheck=0;
00547    Int_t nrules;
00548    Int_t nendn;
00549    Double_t sumnendn=0;
00550    Double_t sumn2=0;
00551    //
00552    UInt_t prevs;
00553    UInt_t ntrees = forest.size();
00554    for ( UInt_t ind=0; ind<ntrees; ind++ ) {
00555       prevs = fRules.size();
00556       MakeRulesFromTree( forest[ind] );
00557       nrules = CalcNRules( forest[ind] );
00558       nendn = (nrules/2) + 1;
00559       sumnendn += nendn;
00560       sumn2    += nendn*nendn;
00561       nrulesCheck += nrules;
00562    }
00563    Double_t nmean = sumnendn/ntrees;
00564    Double_t nsigm = TMath::Sqrt( gTools().ComputeVariance(sumn2,sumnendn,ntrees) );
00565    Double_t ndev = 2.0*(nmean-2.0-nsigm)/(nmean-2.0+nsigm);
00566    //
00567    Log() << kVERBOSE << "Average number of end nodes per tree   = " << nmean << Endl;
00568    if (ntrees>1) Log() << kVERBOSE << "sigma of ditto ( ~= mean-2 ?)          = "
00569                        << nsigm
00570                        << Endl;
00571    Log() << kVERBOSE << "Deviation from exponential model       = " << ndev      << Endl;
00572    Log() << kVERBOSE << "Corresponds to L (eq. 13, RuleFit ppr) = " << nmean << Endl;
00573    // a BUG trap
00574    if (nrulesCheck != static_cast<Int_t>(fRules.size())) {
00575       Log() << kFATAL 
00576             << "BUG! number of generated and possible rules do not match! N(rules) =  " << fRules.size() 
00577             << " != " << nrulesCheck << Endl;
00578    }
00579    Log() << kVERBOSE << "Number of generated rules: " << fRules.size() << Endl;
00580 
00581    // save initial number of rules
00582    fNRulesGenerated = fRules.size();
00583 
00584    RemoveSimilarRules();
00585 
00586    ResetCoefficients();
00587 
00588 }
00589 
00590 //_______________________________________________________________________
00591 void TMVA::RuleEnsemble::MakeLinearTerms()
00592 {
00593    //
00594    // Make the linear terms as in eq 25, ref 2
00595    // For this the b and (1-b) quatiles are needed
00596    //
00597    if (!DoLinear()) return;
00598 
00599    const std::vector<Event *> *events = GetTrainingEvents();
00600    UInt_t neve  = events->size();
00601    UInt_t nvars = ((*events)[0])->GetNVariables(); // Event -> GetNVariables();
00602    Double_t val,ew;
00603    typedef std::pair< Double_t, Int_t> dataType;
00604    typedef std::pair< Double_t, dataType > dataPoint;
00605 
00606    std::vector< std::vector<dataPoint> > vardata(nvars);
00607    std::vector< Double_t > varsum(nvars,0.0);
00608    std::vector< Double_t > varsum2(nvars,0.0);
00609    // first find stats of all variables
00610    // vardata[v][i].first         -> value of var <v> in event <i>
00611    // vardata[v][i].second.first  -> the event weight
00612    // vardata[v][i].second.second -> the event type
00613    for (UInt_t i=0; i<neve; i++) {
00614       ew   = ((*events)[i])->GetWeight();
00615       for (UInt_t v=0; v<nvars; v++) {
00616          val = ((*events)[i])->GetValue(v);
00617          vardata[v].push_back( dataPoint( val, dataType(ew,((*events)[i])->GetClass()) ) );
00618       }
00619    }
00620    //
00621    fLinDP.clear();
00622    fLinDM.clear();
00623    fLinCoefficients.clear();
00624    fLinNorm.clear();
00625    fLinDP.resize(nvars,0);
00626    fLinDM.resize(nvars,0);
00627    fLinCoefficients.resize(nvars,0);
00628    fLinNorm.resize(nvars,0);
00629 
00630    Double_t averageWeight = fRuleFit->GetNEveEff()/static_cast<Double_t>(neve);
00631    // sort and find limits
00632    Double_t stdl;
00633 
00634    // find normalisation given in ref 2 after eq 26
00635    Double_t lx;
00636    Double_t nquant;
00637    Double_t neff;
00638    UInt_t   indquantM;
00639    UInt_t   indquantP;
00640    
00641    for (UInt_t v=0; v<nvars; v++) {
00642       varsum[v] = 0;
00643       varsum2[v] = 0;
00644       //
00645       std::sort( vardata[v].begin(),vardata[v].end() );
00646       nquant = fLinQuantile*fRuleFit->GetNEveEff(); // quantile = 0.025
00647       neff=0;
00648       UInt_t ie=0;
00649       // first scan for lower quantile (including weights)
00650       while ( (ie<neve) && (neff<nquant) ) {
00651          neff += vardata[v][ie].second.first;
00652          ie++;
00653       }
00654       indquantM = (ie==0 ? 0:ie-1);
00655       // now for upper quantile
00656       ie = neve;
00657       neff=0;
00658       while ( (ie>0) && (neff<nquant) ) {
00659          ie--;
00660          neff += vardata[v][ie].second.first;
00661       }
00662       indquantP = (ie==neve ? ie=neve-1:ie);
00663       //
00664       fLinDM[v] = vardata[v][indquantM].first; // delta-
00665       fLinDP[v] = vardata[v][indquantP].first; // delta+
00666       if (fLinPDFB[v]) delete fLinPDFB[v];
00667       if (fLinPDFS[v]) delete fLinPDFS[v];
00668       fLinPDFB[v] = new TH1F(Form("bkgvar%d",v),"bkg temphist",40,fLinDM[v],fLinDP[v]);
00669       fLinPDFS[v] = new TH1F(Form("sigvar%d",v),"sig temphist",40,fLinDM[v],fLinDP[v]);
00670       fLinPDFB[v]->Sumw2();
00671       fLinPDFS[v]->Sumw2();
00672       //
00673       Int_t type;
00674       const Double_t w = 1.0/fRuleFit->GetNEveEff();
00675       for (ie=0; ie<neve; ie++) {
00676          val  = vardata[v][ie].first;
00677          ew   = vardata[v][ie].second.first;
00678          type = vardata[v][ie].second.second;
00679          lx = TMath::Min( fLinDP[v], TMath::Max( fLinDM[v], val ) );
00680          varsum[v] += ew*lx;
00681          varsum2[v] += ew*lx*lx;
00682          if (type==1) fLinPDFS[v]->Fill(lx,w*ew);
00683          else         fLinPDFB[v]->Fill(lx,w*ew);
00684       }
00685       //
00686       // Get normalization.
00687       //
00688       stdl = TMath::Sqrt( (varsum2[v] - (varsum[v]*varsum[v]/fRuleFit->GetNEveEff()))/(fRuleFit->GetNEveEff()-averageWeight) );
00689       fLinNorm[v] = CalcLinNorm(stdl);
00690    }
00691    // Save PDFs - for debugging purpose
00692    for (UInt_t v=0; v<nvars; v++) {
00693       fLinPDFS[v]->Write();
00694       fLinPDFB[v]->Write();
00695    }
00696 }
00697 
00698 
00699 //_______________________________________________________________________
00700 Double_t TMVA::RuleEnsemble::PdfLinear( Double_t & nsig, Double_t & ntot  ) const
00701 {
00702    //
00703    // This function returns Pr( y = 1 | x ) for the linear terms.
00704    //
00705    UInt_t nvars=fLinDP.size();
00706 
00707    Double_t fstot=0;
00708    Double_t fbtot=0;
00709    nsig = 0;
00710    ntot = nvars;
00711    for (UInt_t v=0; v<nvars; v++) {
00712       Double_t val = fEventLinearVal[v];
00713       Int_t bin = fLinPDFS[v]->FindBin(val);
00714       fstot += fLinPDFS[v]->GetBinContent(bin);
00715       fbtot += fLinPDFB[v]->GetBinContent(bin);
00716    }
00717    if (nvars<1) return 0;
00718    ntot = (fstot+fbtot)/Double_t(nvars);
00719    nsig = (fstot)/Double_t(nvars);
00720    return fstot/(fstot+fbtot);
00721 }
00722 
00723 //_______________________________________________________________________
00724 Double_t TMVA::RuleEnsemble::PdfRule( Double_t & nsig, Double_t & ntot  ) const
00725 {
00726    //
00727    // This function returns Pr( y = 1 | x ) for rules.
00728    // The probability returned is normalized against the number of rules which are actually passed
00729    //
00730    Double_t sump  = 0;
00731    Double_t sumok = 0;
00732    Double_t sumz  = 0;
00733    Double_t ssb;
00734    Double_t neve;
00735    //
00736    UInt_t nrules = fRules.size();
00737    for (UInt_t ir=0; ir<nrules; ir++) {
00738       if (fEventRuleVal[ir]>0) {
00739          ssb = fEventRuleVal[ir]*GetRulesConst(ir)->GetSSB(); // S/(S+B) is evaluated in CalcRuleSupport() using ALL training events
00740          neve = GetRulesConst(ir)->GetSSBNeve(); // number of events accepted by the rule
00741          sump  += ssb*neve; // number of signal events
00742          sumok += neve; // total number of events passed
00743       } 
00744       else sumz += 1.0; // all events
00745    }
00746 
00747    nsig = sump;
00748    ntot = sumok;
00749    //
00750    if (ntot>0) return nsig/ntot;
00751    return 0.0;
00752 }
00753 
00754 //_______________________________________________________________________
00755 Double_t TMVA::RuleEnsemble::FStar( const Event & e )
00756 {
00757    //
00758    // We want to estimate F* = argmin Eyx( L(y,F(x) ), min wrt F(x)
00759    // F(x) = FL(x) + FR(x) , linear and rule part
00760    // 
00761    //
00762    SetEvent(e);
00763    UpdateEventVal();
00764    return FStar();
00765 }
00766 
00767 //_______________________________________________________________________
00768 Double_t TMVA::RuleEnsemble::FStar() const
00769 {
00770    //
00771    // We want to estimate F* = argmin Eyx( L(y,F(x) ), min wrt F(x)
00772    // F(x) = FL(x) + FR(x) , linear and rule part
00773    // 
00774    //
00775    Double_t p=0;
00776    Double_t nrs=0, nrt=0;
00777    Double_t nls=0, nlt=0;
00778    Double_t nt;
00779    Double_t pr=0;
00780    Double_t pl=0;
00781 
00782    // first calculate Pr(y=1|X) for rules and linear terms
00783    if (DoLinear()) pl = PdfLinear(nls, nlt);
00784    if (DoRules())  pr = PdfRule(nrs, nrt);
00785    // nr(l)t=0 or 1
00786    if ((nlt>0) && (nrt>0)) nt=2.0;
00787    else                    nt=1.0;
00788    p = (pl+pr)/nt;
00789    return 2.0*p-1.0;
00790 }
00791 
00792 //_______________________________________________________________________
00793 void TMVA::RuleEnsemble::RuleResponseStats()
00794 {
00795    // calculate various statistics for this rule
00796 
00797    // TODO: NOT YET UPDATED FOR WEIGHTS
00798    const std::vector<Event *> *events = GetTrainingEvents();
00799    const UInt_t neve   = events->size();
00800    const UInt_t nvars  = GetMethodBase()->GetNvar();
00801    const UInt_t nrules = fRules.size();
00802    const Event *eveData;
00803    // Flags
00804    Bool_t sigRule;
00805    Bool_t sigTag;
00806    Bool_t bkgTag;
00807    Bool_t noTag;
00808    Bool_t sigTrue;
00809    Bool_t tagged;
00810    // Counters
00811    Int_t nsig=0;
00812    Int_t nbkg=0;
00813    Int_t ntag=0;
00814    Int_t nss=0;
00815    Int_t nsb=0;
00816    Int_t nbb=0;
00817    Int_t nbs=0;
00818    std::vector<Int_t> varcnt;
00819    // Clear vectors
00820    fRulePSS.clear();
00821    fRulePSB.clear();
00822    fRulePBS.clear();
00823    fRulePBB.clear();
00824    fRulePTag.clear();
00825    //
00826    varcnt.resize(nvars,0);
00827    fRuleVarFrac.clear();
00828    fRuleVarFrac.resize(nvars,0);
00829    //
00830    for ( UInt_t i=0; i<nrules; i++ ) {
00831       for ( UInt_t v=0; v<nvars; v++) {
00832          if (fRules[i]->ContainsVariable(v)) varcnt[v]++; // count how often a variable occurs
00833       }
00834       sigRule = fRules[i]->IsSignalRule();
00835       if (sigRule) { // rule is a signal rule (ie s/(s+b)>0.5)
00836          nsig++;
00837       } 
00838       else {
00839          nbkg++;
00840       }
00841       // reset counters
00842       nss=0;
00843       nsb=0;
00844       nbs=0;
00845       nbb=0;
00846       ntag=0;
00847       // loop over all events
00848       for (UInt_t e=0; e<neve; e++) {
00849          eveData = (*events)[e];
00850          tagged  = fRules[i]->EvalEvent(*eveData);
00851          sigTag = (tagged && sigRule);        // it's tagged as a signal
00852          bkgTag = (tagged && (!sigRule));     // ... as bkg
00853          noTag = !(sigTag || bkgTag);         // ... not tagged
00854          sigTrue = (eveData->GetClass() == 0);       // true if event is true signal
00855          if (tagged) {
00856             ntag++;
00857             if (sigTag && sigTrue)  nss++;
00858             if (sigTag && !sigTrue) nsb++;
00859             if (bkgTag && sigTrue)  nbs++;
00860             if (bkgTag && !sigTrue) nbb++;
00861          }
00862       }
00863       // Fill tagging probabilities
00864       fRulePTag.push_back(Double_t(ntag)/Double_t(neve));
00865       fRulePSS.push_back(Double_t(nss)/Double_t(ntag));
00866       fRulePSB.push_back(Double_t(nsb)/Double_t(ntag));
00867       fRulePBS.push_back(Double_t(nbs)/Double_t(ntag));
00868       fRulePBB.push_back(Double_t(nbb)/Double_t(ntag));
00869       //
00870    }
00871    fRuleFSig = static_cast<Double_t>(nsig)/static_cast<Double_t>(nsig+nbkg);
00872    for ( UInt_t v=0; v<nvars; v++) {
00873       fRuleVarFrac[v] =  Double_t(varcnt[v])/Double_t(nrules);
00874    }
00875 }
00876 
00877 //_______________________________________________________________________
00878 void TMVA::RuleEnsemble::RuleStatistics()
00879 {
00880    // calculate various statistics for this rule
00881    const UInt_t nrules = fRules.size();
00882    Double_t nc;
00883    Double_t sumNc =0;
00884    Double_t sumNc2=0;
00885    for ( UInt_t i=0; i<nrules; i++ ) {
00886       nc = static_cast<Double_t>(fRules[i]->GetNcuts());
00887       sumNc  += nc;
00888       sumNc2 += nc*nc;
00889    }
00890    fRuleNCave = 0.0;
00891    fRuleNCsig = 0.0;
00892    if (nrules>0) {
00893       fRuleNCave = sumNc/nrules;
00894       fRuleNCsig = TMath::Sqrt(gTools().ComputeVariance(sumNc2,sumNc,nrules));
00895    }
00896 }
00897 
00898 //_______________________________________________________________________
00899 void TMVA::RuleEnsemble::PrintRuleGen() const
00900 {
00901    // print rule generation info
00902    Log() << kINFO << "-------------------RULE ENSEMBLE SUMMARY------------------------"  << Endl;
00903    const MethodRuleFit *mrf = GetMethodRuleFit();
00904    if (mrf) Log() << kINFO << "Tree training method               : " << (mrf->UseBoost() ? "AdaBoost":"Random") << Endl;
00905    Log() << kINFO << "Number of events per tree          : " << fRuleFit->GetNTreeSample()    << Endl;
00906    Log() << kINFO << "Number of trees                    : " << fRuleFit->GetForest().size() << Endl;
00907    Log() << kINFO << "Number of generated rules          : " << fNRulesGenerated << Endl;
00908    Log() << kINFO << "Idem, after cleanup                : " << fRules.size() << Endl;
00909    Log() << kINFO << "Average number of cuts per rule    : " << Form("%8.2f",fRuleNCave) << Endl;
00910    Log() << kINFO << "Spread in number of cuts per rules : " << Form("%8.2f",fRuleNCsig) << Endl;
00911    Log() << kVERBOSE << "Complexity                         : " << Form("%8.2f",fRuleNCave*fRuleNCsig) << Endl;
00912    Log() << kINFO << "----------------------------------------------------------------"  << Endl;
00913    Log() << kINFO << Endl;
00914 }
00915 
00916 //_______________________________________________________________________
00917 void TMVA::RuleEnsemble::Print() const
00918 {
00919    // print function
00920 
00921    const EMsgType kmtype=kINFO;
00922    const Bool_t   isDebug = (fLogger->GetMinType()<=kDEBUG);
00923    //
00924    Log() << kmtype << Endl;
00925    Log() << kmtype << "================================================================" << Endl;
00926    Log() << kmtype << "                          M o d e l                             " << Endl;
00927    Log() << kmtype << "================================================================" << Endl;
00928 
00929    Int_t ind;
00930    const UInt_t nvars =  GetMethodBase()->GetNvar();
00931    const Int_t nrules = fRules.size();
00932    const Int_t printN = TMath::Min(10,nrules); //nrules+1;
00933    Int_t maxL = 0;
00934    for (UInt_t iv = 0; iv<fVarImportance.size(); iv++) {
00935       if (GetMethodBase()->GetInputLabel(iv).Length() > maxL) maxL = GetMethodBase()->GetInputLabel(iv).Length();
00936    }
00937    //
00938    if (isDebug) {
00939       Log() << kDEBUG << "Variable importance:" << Endl;
00940       for (UInt_t iv = 0; iv<fVarImportance.size(); iv++) {
00941          Log() << kDEBUG << std::setw(maxL) << GetMethodBase()->GetInputLabel(iv) 
00942                << std::resetiosflags(std::ios::right) 
00943                << " : " << Form(" %3.3f",fVarImportance[iv]) << Endl;
00944       }
00945    }
00946    //
00947    Log() << kmtype << "Offset (a0) = " << fOffset << Endl;
00948    //
00949    if (DoLinear()) {
00950       if (fLinNorm.size() > 0) {
00951          Log() << kmtype << "------------------------------------" << Endl;
00952          Log() << kmtype << "Linear model (weights unnormalised)" << Endl;
00953          Log() << kmtype << "------------------------------------" << Endl;
00954          Log() << kmtype << std::setw(maxL) << "Variable"
00955                << std::resetiosflags(std::ios::right) << " : "
00956                << std::setw(11) << " Weights"
00957                << std::resetiosflags(std::ios::right) << " : "
00958                << "Importance"
00959                << std::resetiosflags(std::ios::right)
00960                << Endl;
00961          Log() << kmtype << "------------------------------------" << Endl;
00962          for ( UInt_t i=0; i<fLinNorm.size(); i++ ) {
00963             Log() << kmtype << std::setw(std::max(maxL,8)) << GetMethodBase()->GetInputLabel(i);
00964             if (fLinTermOK[i]) {
00965                Log() << kmtype
00966                      << std::resetiosflags(std::ios::right)
00967                      << " : " << Form(" %10.3e",fLinCoefficients[i]*fLinNorm[i])
00968                      << " : " << Form(" %3.3f",fLinImportance[i]/fImportanceRef) << Endl;
00969             } 
00970             else {
00971                Log() << kmtype << "-> importance below threshhold = "
00972                      << Form(" %3.3f",fLinImportance[i]/fImportanceRef) << Endl;
00973             }
00974          }
00975          Log() << kmtype << "------------------------------------" << Endl;
00976       }
00977    } 
00978    else Log() << kmtype << "Linear terms were disabled" << Endl;
00979 
00980    if ((!DoRules()) || (nrules==0)) {
00981       if (!DoRules()) {
00982          Log() << kmtype << "Rule terms were disabled" << Endl;
00983       } 
00984       else {
00985          Log() << kmtype << "Eventhough rules were included in the model, none passed! " << nrules << Endl;
00986       }
00987    } 
00988    else {
00989       Log() << kmtype << "Number of rules = " << nrules << Endl;
00990       if (isDebug) {
00991          Log() << kmtype << "N(cuts) in rules, average = " << fRuleNCave << Endl;
00992          Log() << kmtype << "                      RMS = " << fRuleNCsig << Endl;
00993          Log() << kmtype << "Fraction of signal rules = " << fRuleFSig << Endl;
00994          Log() << kmtype << "Fraction of rules containing a variable (%):" << Endl;
00995          for ( UInt_t v=0; v<nvars; v++) {
00996             Log() << kmtype << "   " << std::setw(maxL) << GetMethodBase()->GetInputLabel(v);
00997             Log() << kmtype << Form(" = %2.2f",fRuleVarFrac[v]*100.0) << " %" << Endl;
00998          }
00999       }
01000       //
01001       // Print out all rules sorted in importance
01002       //
01003       std::list< std::pair<double,int> > sortedImp;
01004       for (Int_t i=0; i<nrules; i++) {
01005          sortedImp.push_back( std::pair<double,int>( fRules[i]->GetImportance(),i ) );
01006       }
01007       sortedImp.sort();
01008       //
01009       Log() << kmtype << "Printing the first " << printN << " rules, ordered in importance." << Endl;
01010       int pind=0;
01011       for ( std::list< std::pair<double,int> >::reverse_iterator itpair = sortedImp.rbegin();
01012             itpair != sortedImp.rend(); itpair++ ) {
01013          ind = itpair->second;
01014          //    if (pind==0) impref = 
01015          //         Log() << kmtype << "Rule #" << 
01016          //         Log() << kmtype << *fRules[ind] << Endl;
01017          fRules[ind]->PrintLogger(Form("Rule %4d : ",pind+1));
01018          pind++;
01019          if (pind==printN) {
01020             if (nrules==printN) {
01021                Log() << kmtype << "All rules printed" << Endl;
01022             } 
01023             else {
01024                Log() << kmtype << "Skipping the next " << nrules-printN << " rules" << Endl;
01025             }
01026             break;
01027          }
01028       }
01029    }
01030    Log() << kmtype << "================================================================" << Endl;
01031    Log() << kmtype << Endl;
01032 }
01033 
01034 //_______________________________________________________________________
01035 void TMVA::RuleEnsemble::PrintRaw( ostream & os ) const
01036 {
01037    // write rules to stream
01038    Int_t dp = os.precision();
01039    UInt_t nrules = fRules.size();
01040    //   std::sort(fRules.begin(),fRules.end());
01041    //
01042    os << "ImportanceCut= "    << fImportanceCut << std::endl;
01043    os << "LinQuantile= "      << fLinQuantile   << std::endl;
01044    os << "AverageSupport= "   << fAverageSupport << std::endl;
01045    os << "AverageRuleSigma= " << fAverageRuleSigma << std::endl;
01046    os << "Offset= "           << fOffset << std::endl;
01047    os << "NRules= "           << nrules << std::endl; 
01048    for (UInt_t i=0; i<nrules; i++){
01049       os << "***Rule " << i << std::endl;
01050       (fRules[i])->PrintRaw(os);
01051    }
01052    UInt_t nlinear = fLinNorm.size();
01053    //
01054    os << "NLinear= " << fLinTermOK.size() << std::endl;
01055    for (UInt_t i=0; i<nlinear; i++) {
01056       os << "***Linear " << i << std::endl;
01057       os << std::setprecision(10) << (fLinTermOK[i] ? 1:0) << " "
01058          << fLinCoefficients[i] << " "
01059          << fLinNorm[i] << " "
01060          << fLinDM[i] << " "
01061          << fLinDP[i] << " "
01062          << fLinImportance[i] << " " << std::endl;
01063    }
01064    os << std::setprecision(dp);
01065 }
01066 
01067 //_______________________________________________________________________
01068 void* TMVA::RuleEnsemble::AddXMLTo(void* parent) const
01069 {
01070    // write rules to XML
01071    void* re = gTools().AddChild( parent, "Weights" ); // this is the "RuleEnsemble"
01072 
01073    UInt_t nrules  = fRules.size();
01074    UInt_t nlinear = fLinNorm.size();
01075    gTools().AddAttr( re, "NRules",           nrules );
01076    gTools().AddAttr( re, "NLinear",          nlinear );
01077    gTools().AddAttr( re, "LearningModel",    (int)fLearningModel );
01078    gTools().AddAttr( re, "ImportanceCut",    fImportanceCut );
01079    gTools().AddAttr( re, "LinQuantile",      fLinQuantile );
01080    gTools().AddAttr( re, "AverageSupport",   fAverageSupport );
01081    gTools().AddAttr( re, "AverageRuleSigma", fAverageRuleSigma );
01082    gTools().AddAttr( re, "Offset",           fOffset );
01083    for (UInt_t i=0; i<nrules; i++) fRules[i]->AddXMLTo(re);
01084 
01085    for (UInt_t i=0; i<nlinear; i++) {
01086       void* lin = gTools().AddChild( re, "Linear" );
01087       gTools().AddAttr( lin, "OK",         (fLinTermOK[i] ? 1:0) );
01088       gTools().AddAttr( lin, "Coeff",      fLinCoefficients[i] );
01089       gTools().AddAttr( lin, "Norm",       fLinNorm[i] );
01090       gTools().AddAttr( lin, "DM",         fLinDM[i] );
01091       gTools().AddAttr( lin, "DP",         fLinDP[i] );
01092       gTools().AddAttr( lin, "Importance", fLinImportance[i] );
01093    }
01094    return re;
01095 }
01096 
01097 //_______________________________________________________________________
01098 void TMVA::RuleEnsemble::ReadFromXML( void* wghtnode ) 
01099 {
01100    // read rules from XML
01101    UInt_t nrules, nlinear;
01102    gTools().ReadAttr( wghtnode, "NRules",   nrules );
01103    gTools().ReadAttr( wghtnode, "NLinear",  nlinear );
01104    Int_t iLearningModel;
01105    gTools().ReadAttr( wghtnode, "LearningModel",     iLearningModel );
01106    fLearningModel =  (ELearningModel) iLearningModel;
01107    gTools().ReadAttr( wghtnode, "ImportanceCut",     fImportanceCut );
01108    gTools().ReadAttr( wghtnode, "LinQuantile",       fLinQuantile );
01109    gTools().ReadAttr( wghtnode, "AverageSupport",    fAverageSupport );
01110    gTools().ReadAttr( wghtnode, "AverageRuleSigma",  fAverageRuleSigma );
01111    gTools().ReadAttr( wghtnode, "Offset",            fOffset );
01112 
01113    // read rules
01114    DeleteRules();
01115 
01116    UInt_t i = 0;
01117    fRules.resize( nrules  );
01118    void* ch = gTools().GetChild( wghtnode );
01119    for (i=0; i<nrules; i++) {
01120       fRules[i] = new Rule();
01121       fRules[i]->SetRuleEnsemble( this );
01122       fRules[i]->ReadFromXML( ch );
01123 
01124       ch = gTools().GetNextChild(ch);
01125    }
01126 
01127    // read linear classifier (Fisher)
01128    fLinNorm        .resize( nlinear );
01129    fLinTermOK      .resize( nlinear );
01130    fLinCoefficients.resize( nlinear );
01131    fLinDP          .resize( nlinear );
01132    fLinDM          .resize( nlinear );
01133    fLinImportance  .resize( nlinear );
01134 
01135    Int_t iok;
01136    i=0;
01137    while(ch) {
01138       gTools().ReadAttr( ch, "OK",         iok );
01139       fLinTermOK[i] = (iok == 1);
01140       gTools().ReadAttr( ch, "Coeff",      fLinCoefficients[i]  );
01141       gTools().ReadAttr( ch, "Norm",       fLinNorm[i]          );
01142       gTools().ReadAttr( ch, "DM",         fLinDM[i]            );
01143       gTools().ReadAttr( ch, "DP",         fLinDP[i]            );
01144       gTools().ReadAttr( ch, "Importance", fLinImportance[i]    );
01145 
01146       i++;
01147       ch = gTools().GetNextChild(ch);
01148    }
01149 }
01150 
01151 //_______________________________________________________________________
01152 void TMVA::RuleEnsemble::ReadRaw( istream & istr )
01153 {
01154    // read rule ensemble from stream
01155    UInt_t nrules;
01156    //
01157    std::string dummy;
01158    Int_t idum;
01159    //
01160    // First block is general stuff
01161    //
01162    istr >> dummy >> fImportanceCut;
01163    istr >> dummy >> fLinQuantile;
01164    istr >> dummy >> fAverageSupport;
01165    istr >> dummy >> fAverageRuleSigma;
01166    istr >> dummy >> fOffset;
01167    istr >> dummy >> nrules;
01168    //
01169    // Now read in the rules
01170    //
01171    DeleteRules();
01172    //
01173    for (UInt_t i=0; i<nrules; i++){
01174       istr >> dummy >> idum; // read line  "***Rule <ind>"
01175       fRules.push_back( new Rule() );
01176       (fRules.back())->SetRuleEnsemble( this );
01177       (fRules.back())->ReadRaw(istr);
01178    }
01179    //
01180    // and now the linear terms
01181    //
01182    UInt_t nlinear;
01183    //
01184    // coverity[tainted_data_argument]
01185    istr >> dummy >> nlinear;
01186    //
01187    fLinNorm        .resize( nlinear );
01188    fLinTermOK      .resize( nlinear );
01189    fLinCoefficients.resize( nlinear );
01190    fLinDP          .resize( nlinear );
01191    fLinDM          .resize( nlinear );
01192    fLinImportance  .resize( nlinear );
01193    //
01194 
01195    Int_t iok;
01196    for (UInt_t i=0; i<nlinear; i++) {
01197       istr >> dummy >> idum;
01198       istr >> iok;
01199       fLinTermOK[i] = (iok==1);
01200       istr >> fLinCoefficients[i];
01201       istr >> fLinNorm[i];
01202       istr >> fLinDM[i];
01203       istr >> fLinDP[i];
01204       istr >> fLinImportance[i];
01205    }
01206 }
01207 
01208 //_______________________________________________________________________
01209 void TMVA::RuleEnsemble::Copy( const RuleEnsemble & other )
01210 {
01211    // copy function
01212    if(this != &other) {
01213       fRuleFit           = other.GetRuleFit();
01214       fRuleMinDist       = other.GetRuleMinDist();
01215       fOffset            = other.GetOffset();
01216       fRules             = other.GetRulesConst();
01217       fImportanceCut     = other.GetImportanceCut();
01218       fVarImportance     = other.GetVarImportance();
01219       fLearningModel     = other.GetLearningModel();
01220       fLinQuantile       = other.GetLinQuantile();
01221       fRuleNCsig         = other.fRuleNCsig;
01222       fAverageRuleSigma  = other.fAverageRuleSigma;
01223       fEventCacheOK      = other.fEventCacheOK;
01224       fImportanceRef     = other.fImportanceRef;
01225       fNRulesGenerated   = other.fNRulesGenerated;
01226       fRuleFSig          = other.fRuleFSig;
01227       fRuleMapInd0       = other.fRuleMapInd0;
01228       fRuleMapInd1       = other.fRuleMapInd1;
01229       fRuleMapOK         = other.fRuleMapOK;
01230       fRuleNCave         = other.fRuleNCave;
01231    }
01232 }
01233 
01234 //_______________________________________________________________________
01235 Int_t TMVA::RuleEnsemble::CalcNRules( const DecisionTree *dtree )
01236 {
01237    // calculate the number of rules
01238    if (dtree==0) return 0;
01239    Node *node = dtree->GetRoot();
01240    Int_t nendnodes = 0;
01241    FindNEndNodes( node, nendnodes );
01242    return 2*(nendnodes-1);
01243 }
01244 
01245 //_______________________________________________________________________
01246 void TMVA::RuleEnsemble::FindNEndNodes( const Node *node, Int_t & nendnodes )
01247 {
01248    // find the number of leaf nodes
01249 
01250    if (node==0) return;
01251    if ((node->GetRight()==0) && (node->GetLeft()==0)) {
01252       ++nendnodes;
01253       return;
01254    }
01255    const Node *nodeR = node->GetRight();
01256    const Node *nodeL = node->GetLeft();
01257    FindNEndNodes( nodeR, nendnodes );
01258    FindNEndNodes( nodeL, nendnodes );
01259 }
01260 
01261 //_______________________________________________________________________
01262 void TMVA::RuleEnsemble::MakeRulesFromTree( const DecisionTree *dtree )
01263 {
01264    // create rules from the decsision tree structure 
01265    Node *node = dtree->GetRoot();
01266    AddRule( node );
01267 }
01268 
01269 //_______________________________________________________________________
01270 void TMVA::RuleEnsemble::AddRule( const Node *node )
01271 {
01272    // add a new rule to the tree
01273 
01274    if (node==0) return;
01275    if (node->GetParent()==0) { // it's a root node, don't make a rule
01276       AddRule( node->GetRight() );
01277       AddRule( node->GetLeft() );
01278    } 
01279    else {
01280       Rule *rule = MakeTheRule(node);
01281       if (rule) {
01282          fRules.push_back( rule );
01283          AddRule( node->GetRight() );
01284          AddRule( node->GetLeft() );
01285       } 
01286       else {
01287          Log() << kFATAL << "<AddRule> - ERROR failed in creating a rule! BUG!" << Endl;
01288       }
01289    }
01290 }
01291 
01292 //_______________________________________________________________________
01293 TMVA::Rule *TMVA::RuleEnsemble::MakeTheRule( const Node *node )
01294 {
01295    //
01296    // Make a Rule from a given Node.
01297    // The root node (ie no parent) does not generate a Rule.
01298    // The first node in a rule is always the root node => fNodes.size()>=2
01299    // Each node corresponds to a cut and the cut value is given by the parent node.
01300    //
01301    //
01302    if (node==0) {
01303       Log() << kFATAL << "<MakeTheRule> Input node is NULL. Should not happen. BUG!" << Endl;
01304       return 0;
01305    }
01306 
01307    if (node->GetParent()==0) { // a root node - ignore
01308       return 0;
01309    }
01310    //
01311    std::vector< const Node * > nodeVec;
01312    const Node *parent = node;
01313    //
01314    // Make list with the input node at the end:
01315    // <root node> <node1> <node2> ... <node given as argument>
01316    // 
01317    nodeVec.push_back( node );
01318    while (parent!=0) {
01319       parent = parent->GetParent();
01320       if (!parent) continue;
01321       const DecisionTreeNode* dtn = dynamic_cast<const DecisionTreeNode*>(parent);
01322       if (dtn && dtn->GetSelector()>=0)
01323          nodeVec.insert( nodeVec.begin(), parent );
01324 
01325    }
01326    if (nodeVec.size()<2) {
01327       Log() << kFATAL << "<MakeTheRule> BUG! Inconsistent Rule!" << Endl;
01328       return 0;
01329    }
01330    Rule *rule = new Rule( this, nodeVec );
01331    rule->SetMsgType( Log().GetMinType() );
01332    return rule;
01333 }
01334 
01335 //_______________________________________________________________________
01336 void TMVA::RuleEnsemble::MakeRuleMap(const std::vector<Event *> *events, UInt_t ifirst, UInt_t ilast)
01337 {
01338    // Makes rule map for all events
01339 
01340    Log() << kVERBOSE << "Making Rule map for all events" << Endl;
01341    // make rule response map
01342    if (events==0) events = GetTrainingEvents();
01343    if ((ifirst==0) || (ilast==0) || (ifirst>ilast)) {
01344       ifirst = 0;
01345       ilast  = events->size()-1;
01346    }
01347    // check if identical to previous call
01348    if ((events!=fRuleMapEvents) ||
01349        (ifirst!=fRuleMapInd0) ||
01350        (ilast !=fRuleMapInd1)) {
01351       fRuleMapOK = kFALSE;
01352    }
01353    //
01354    if (fRuleMapOK) {
01355       Log() << kVERBOSE << "<MakeRuleMap> Map is already valid" << Endl;
01356       return;  // already cached
01357    }
01358    fRuleMapEvents = events;
01359    fRuleMapInd0   = ifirst;
01360    fRuleMapInd1   = ilast;
01361    // check number of rules
01362    UInt_t nrules = GetNRules(); 
01363    if (nrules==0) {
01364       Log() << kVERBOSE << "No rules found in MakeRuleMap()" << Endl;
01365       fRuleMapOK = kTRUE;
01366       return;
01367    }
01368    //
01369    // init map
01370    //
01371    std::vector<UInt_t> ruleind;
01372    fRuleMap.clear();
01373    for (UInt_t i=ifirst; i<=ilast; i++) {
01374       ruleind.clear();
01375       fRuleMap.push_back( ruleind );
01376       for (UInt_t r=0; r<nrules; r++) {
01377          if (fRules[r]->EvalEvent(*((*events)[i]))) {
01378             fRuleMap.back().push_back(r); // save only rules that are accepted
01379          }
01380       }
01381    }
01382    fRuleMapOK = kTRUE;
01383    Log() << kVERBOSE << "Made rule map for event# " << ifirst << " : " << ilast << Endl;
01384 }
01385 
01386 //_______________________________________________________________________
01387 ostream& TMVA::operator<< ( ostream& os, const RuleEnsemble & rules )
01388 {
01389    // ostream operator
01390    os << "DON'T USE THIS - TO BE REMOVED" << std::endl;
01391    rules.Print();
01392    return os;
01393 }

Generated on Tue Jul 5 15:25:35 2011 for ROOT_528-00b_version by  doxygen 1.5.1