RuleFit.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: RuleFit.cxx 36966 2010-11-26 09:50:13Z evt $
00002 // Author: Andreas Hoecker, Joerg Stelzer, Fredrik Tegenfeldt, Helge Voss
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : RuleFit                                                               *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      A class describing a 'rule'                                               *
00012  *      Each internal node of a tree defines a rule from all the parental nodes.  *
00013  *      A rule with 0 or 1 nodes in the list is a root rule -> corresponds to a0. *
00014  *      Input: a decision tree (in the constructor)                               *
00015  *             its coefficient                                                    *
00016  *                                                                                *
00017  *                                                                                *
00018  * Authors (alphabetical):                                                        *
00019  *      Fredrik Tegenfeldt <Fredrik.Tegenfeldt@cern.ch> - Iowa State U., USA      *
00020  *                                                                                *
00021  * Copyright (c) 2005:                                                            *
00022  *      CERN, Switzerland                                                         *
00023  *      Iowa State U.                                                             *
00024  *      MPI-K Heidelberg, Germany                                                 *
00025  *                                                                                *
00026  * Redistribution and use in source and binary forms, with or without             *
00027  * modification, are permitted according to the terms listed in LICENSE           *
00028  * (http://tmva.sourceforge.net/LICENSE)                                          *
00029  **********************************************************************************/
00030 
00031 #include <algorithm>
00032 
00033 #include "TKey.h"
00034 #include "TRandom3.h"
00035 
00036 #include "TMVA/SeparationBase.h"
00037 #include "TMVA/GiniIndex.h"
00038 #include "TMVA/RuleFit.h"
00039 #include "TMVA/MethodRuleFit.h"
00040 #include "TMVA/Timer.h"
00041 #include "TMVA/Tools.h"
00042 #include "TMVA/Factory.h" // for root base dir
00043 
00044 ClassImp(TMVA::RuleFit)
00045 
00046 //_______________________________________________________________________
00047 TMVA::RuleFit::RuleFit( const MethodBase *rfbase )
00048    : fVisHistsUseImp( kTRUE ),
00049      fLogger( new MsgLogger("RuleFit") )
00050 {
00051    // constructor
00052    Initialize( rfbase );
00053    std::srand( randSEED );  // initialize random number generator used by std::random_shuffle
00054 }
00055 
00056 //_______________________________________________________________________
00057 TMVA::RuleFit::RuleFit()
00058    : fNTreeSample(0)
00059    , fNEveEffTrain(0)
00060    , fMethodRuleFit(0)
00061    , fMethodBase(0)
00062    , fVisHistsUseImp( kTRUE )
00063    , fLogger( new MsgLogger("RuleFit") )
00064 {
00065    // default constructor
00066    std::srand( randSEED ); // initialize random number generator used by std::random_shuffle
00067 }
00068 
00069 //_______________________________________________________________________
00070 TMVA::RuleFit::~RuleFit()
00071 {
00072    // destructor
00073    delete fLogger;
00074 }
00075 
00076 //_______________________________________________________________________
00077 void TMVA::RuleFit::InitNEveEff()
00078 {
00079    // init effective number of events (using event weights)
00080    UInt_t neve = fTrainingEvents.size();
00081    if (neve==0) return;
00082    //
00083    fNEveEffTrain = CalcWeightSum( &fTrainingEvents );
00084    //
00085 }
00086 
00087 //_______________________________________________________________________
00088 void TMVA::RuleFit::InitPtrs(  const MethodBase *rfbase )
00089 {
00090    // initialize pointers
00091    this->SetMethodBase(rfbase);
00092    fRuleEnsemble.Initialize( this );
00093    fRuleFitParams.SetRuleFit( this );
00094 }
00095 
00096 //_______________________________________________________________________
00097 void TMVA::RuleFit::Initialize(  const MethodBase *rfbase )
00098 {
00099    // initialize the parameters of the RuleFit method and make rules
00100    InitPtrs(rfbase);
00101 
00102    if (fMethodRuleFit) 
00103       SetTrainingEvents( fMethodRuleFit->GetTrainingEvents() );
00104 
00105    InitNEveEff();
00106 
00107    MakeForest();
00108 
00109    // Make the model - Rule + Linear (if fDoLinear is true)
00110    fRuleEnsemble.MakeModel();
00111 
00112    // init rulefit params
00113    fRuleFitParams.Init();
00114 
00115 }
00116 
00117 //_______________________________________________________________________
00118 void TMVA::RuleFit::SetMethodBase( const MethodBase *rfbase )
00119 {
00120    // set MethodBase
00121    fMethodBase = rfbase;
00122    fMethodRuleFit = dynamic_cast<const MethodRuleFit *>(rfbase);
00123 }
00124 
00125 //_______________________________________________________________________
00126 void TMVA::RuleFit::Copy( const RuleFit& other )
00127 {
00128    // copy method
00129    if(this != &other) {
00130       fMethodRuleFit   = other.GetMethodRuleFit();
00131       fMethodBase      = other.GetMethodBase();
00132       fTrainingEvents  = other.GetTrainingEvents();
00133       //      fSubsampleEvents = other.GetSubsampleEvents();
00134    
00135       fForest       = other.GetForest();
00136       fRuleEnsemble = other.GetRuleEnsemble();
00137    }
00138 }
00139 
00140 //_______________________________________________________________________
00141 Double_t TMVA::RuleFit::CalcWeightSum( const std::vector<Event *> *events, UInt_t neve )
00142 {
00143    // calculate the sum of weights
00144    if (events==0) return 0.0;
00145    if (neve==0) neve=events->size();
00146    //
00147    Double_t sumw=0;
00148    for (UInt_t ie=0; ie<neve; ie++) {
00149       sumw += ((*events)[ie])->GetWeight();
00150    }
00151    return sumw;
00152 }
00153 
00154 //_______________________________________________________________________
00155 void TMVA::RuleFit::SetMsgType( EMsgType t )
00156 {
00157    // set the current message type to that of mlog for this class and all other subtools
00158    fLogger->SetMinType(t);
00159    fRuleEnsemble.SetMsgType(t);
00160    fRuleFitParams.SetMsgType(t);
00161 }
00162 
00163 //_______________________________________________________________________
00164 void TMVA::RuleFit::BuildTree( DecisionTree *dt )
00165 {
00166    // build the decision tree using fNTreeSample events from fTrainingEventsRndm
00167    if (dt==0) return;
00168    if (fMethodRuleFit==0) {
00169       Log() << kFATAL << "RuleFit::BuildTree() - Attempting to build a tree NOT from a MethodRuleFit" << Endl;
00170    }
00171    std::vector<Event *> evevec;
00172    for (UInt_t ie=0; ie<fNTreeSample; ie++) {
00173       evevec.push_back(fTrainingEventsRndm[ie]);
00174    }
00175    dt->BuildTree(evevec);
00176    if (fMethodRuleFit->GetPruneMethod() != DecisionTree::kNoPruning) {
00177       dt->SetPruneMethod(fMethodRuleFit->GetPruneMethod());
00178       dt->SetPruneStrength(fMethodRuleFit->GetPruneStrength());
00179       dt->PruneTree();
00180    }
00181 }
00182 
00183 //_______________________________________________________________________
00184 void TMVA::RuleFit::MakeForest()
00185 {
00186    // make a forest of decisiontrees
00187    if (fMethodRuleFit==0) {
00188       Log() << kFATAL << "RuleFit::BuildTree() - Attempting to build a tree NOT from a MethodRuleFit" << Endl;
00189    }
00190    Log() << kDEBUG << "Creating a forest with " << fMethodRuleFit->GetNTrees() << " decision trees" << Endl;
00191    Log() << kDEBUG << "Each tree is built using a random subsample with " << fNTreeSample << " events" << Endl;
00192    //
00193    Timer timer( fMethodRuleFit->GetNTrees(), "RuleFit" );
00194 
00195    Double_t fsig;
00196    Int_t nsig,nbkg;
00197    //
00198    TRandom3 rndGen;
00199    //
00200    Int_t nminRnd;
00201    //
00202    // First save all event weights.
00203    // Weights are modifed by the boosting.
00204    // Those weights we do not want for the later fitting.
00205    //
00206    Bool_t useBoost = fMethodRuleFit->UseBoost(); // (AdaBoost (True) or RandomForest/Tree (False)
00207 
00208    if (useBoost) SaveEventWeights();
00209 
00210    for (Int_t i=0; i<fMethodRuleFit->GetNTrees(); i++) {
00211       //      timer.DrawProgressBar(i);
00212       if (!useBoost) ReshuffleEvents();
00213       nsig=0;
00214       nbkg=0;
00215       for (UInt_t ie = 0; ie<fNTreeSample; ie++) {
00216          if (fMethodBase->DataInfo().IsSignal(fTrainingEventsRndm[ie])) nsig++; // ignore weights here
00217          else nbkg++;
00218       }
00219       fsig = Double_t(nsig)/Double_t(nsig+nbkg);
00220       // do not implement the above in this release...just set it to default
00221       //      nminRnd = fNodeMinEvents;
00222       DecisionTree *dt;
00223       Bool_t tryAgain=kTRUE;
00224       Int_t ntries=0;
00225       const Int_t ntriesMax=10;
00226       while (tryAgain) {
00227          Double_t frnd = rndGen.Uniform( fMethodRuleFit->GetMinFracNEve(), fMethodRuleFit->GetMaxFracNEve() );
00228          nminRnd = Int_t(frnd*static_cast<Double_t>(fNTreeSample));
00229          Int_t     iclass = 0; // event class being treated as signal during training
00230          Bool_t    useRandomisedTree = !useBoost;  
00231          dt = new DecisionTree( fMethodRuleFit->GetSeparationBase(), nminRnd, fMethodRuleFit->GetNCuts(), iclass, useRandomisedTree);
00232 
00233          BuildTree(dt); // reads fNTreeSample events from fTrainingEventsRndm
00234          if (dt->GetNNodes()<3) {
00235             delete dt;
00236             dt=0;
00237          }
00238          ntries++;
00239          tryAgain = ((dt==0) && (ntries<ntriesMax));
00240       }
00241       if (dt) {
00242          fForest.push_back(dt);
00243          if (useBoost) Boost(dt);
00244 
00245       } else {
00246 
00247          Log() << kWARNING << "------------------------------------------------------------------" << Endl;
00248          Log() << kWARNING << " Failed growing a tree even after " << ntriesMax << " trials" << Endl;
00249          Log() << kWARNING << " Possible solutions: " << Endl;
00250          Log() << kWARNING << "   1. increase the number of training events" << Endl;
00251          Log() << kWARNING << "   2. set a lower min fraction cut (fEventsMin)" << Endl;
00252          Log() << kWARNING << "   3. maybe also decrease the max fraction cut (fEventsMax)" << Endl;
00253          Log() << kWARNING << " If the above warning occurs rarely only, it can be ignored" << Endl;
00254          Log() << kWARNING << "------------------------------------------------------------------" << Endl;
00255       }
00256 
00257       Log() << kDEBUG << "Built tree with minimum cut at N = " << nminRnd
00258               << " => N(nodes) = " << fForest.back()->GetNNodes()
00259               << " ; n(tries) = " << ntries
00260               << Endl;
00261    }
00262 
00263    // Now restore event weights
00264    if (useBoost) RestoreEventWeights();
00265 
00266    // print statistics on the forest created
00267    ForestStatistics();
00268 }
00269 
00270 //_______________________________________________________________________
00271 void TMVA::RuleFit::SaveEventWeights()
00272 {
00273    // save event weights - must be done before making the forest
00274    fEventWeights.clear();
00275    for (std::vector<Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); e++) {
00276       Double_t w = (*e)->GetWeight();
00277       fEventWeights.push_back(w);
00278    }
00279 }
00280 
00281 //_______________________________________________________________________
00282 void TMVA::RuleFit::RestoreEventWeights()
00283 {
00284    // save event weights - must be done before making the forest
00285    UInt_t ie=0;
00286    if (fEventWeights.size() != fTrainingEvents.size()) {
00287       Log() << kERROR << "RuleFit::RestoreEventWeights() called without having called SaveEventWeights() before!" << Endl;
00288       return;
00289    }
00290    for (std::vector<Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); e++) {
00291       (*e)->SetWeight(fEventWeights[ie]);
00292       ie++;
00293    }
00294 }
00295 
00296 //_______________________________________________________________________
00297 void TMVA::RuleFit::Boost( DecisionTree *dt )
00298 {
00299    // Boost the events. The algorithm below is the called AdaBoost.
00300    // See MethodBDT for details.
00301    // Actually, this is a more or less copy of MethodBDT::AdaBoost().
00302    Double_t sumw=0;      // sum of initial weights - all events
00303    Double_t sumwfalse=0; // idem, only missclassified events
00304    //
00305    std::vector<Char_t> correctSelected; // <--- boolean stored
00306    //
00307    for (std::vector<Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); e++) {
00308       Bool_t isSignalType = (dt->CheckEvent(*(*e),kTRUE) > 0.5 );
00309       Double_t w = (*e)->GetWeight();
00310       sumw += w;
00311       // 
00312       if (isSignalType == fMethodBase->DataInfo().IsSignal(*e)) { // correctly classified
00313          correctSelected.push_back(kTRUE);
00314       } 
00315       else {                                // missclassified
00316          sumwfalse+= w;
00317          correctSelected.push_back(kFALSE);
00318       }    
00319    }
00320    // missclassification error
00321    Double_t err = sumwfalse/sumw;
00322    // calculate boost weight for missclassified events
00323    // use for now the exponent = 1.0
00324    // one could have w = ((1-err)/err)^beta
00325    Double_t boostWeight = (err>0 ? (1.0-err)/err : 1000.0);
00326    Double_t newSumw=0.0;
00327    UInt_t ie=0;
00328    // set new weight to missclassified events
00329    for (std::vector<Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); e++) {
00330       if (!correctSelected[ie])
00331          (*e)->SetWeight( (*e)->GetWeight() * boostWeight);
00332       newSumw+=(*e)->GetWeight();    
00333       ie++;
00334    }
00335    // reweight all events
00336    Double_t scale = sumw/newSumw;
00337    for (std::vector<Event*>::iterator e=fTrainingEvents.begin(); e!=fTrainingEvents.end(); e++) {
00338       (*e)->SetWeight( (*e)->GetWeight() * scale);
00339    }
00340    Log() << kDEBUG << "boostWeight = " << boostWeight << "    scale = " << scale << Endl;
00341 }
00342 
00343 //_______________________________________________________________________
00344 void TMVA::RuleFit::ForestStatistics()
00345 {
00346    // summary of statistics of all trees
00347    // * end-nodes: average and spread
00348    UInt_t ntrees = fForest.size();
00349    if (ntrees==0) return;
00350    const DecisionTree *tree;
00351    Double_t sumn2 = 0;
00352    Double_t sumn  = 0;
00353    Double_t nd;
00354    for (UInt_t i=0; i<ntrees; i++) {
00355       tree = fForest[i];
00356       nd = Double_t(tree->GetNNodes());
00357       sumn  += nd;
00358       sumn2 += nd*nd;
00359    }
00360    Double_t sig = TMath::Sqrt( gTools().ComputeVariance( sumn2, sumn, ntrees ));
00361    Log() << kVERBOSE << "Nodes in trees: average & std dev = " << sumn/ntrees << " , " << sig << Endl;
00362 }
00363 
00364 //_______________________________________________________________________
00365 void TMVA::RuleFit::FitCoefficients()
00366 {
00367    //
00368    // Fit the coefficients for the rule ensemble
00369    //
00370    Log() << kVERBOSE << "Fitting rule/linear terms" << Endl;
00371    fRuleFitParams.MakeGDPath();
00372 }
00373 
00374 //_______________________________________________________________________
00375 void TMVA::RuleFit::CalcImportance()
00376 {
00377    // calculates the importance of each rule
00378 
00379    Log() << kVERBOSE << "Calculating importance" << Endl;
00380    fRuleEnsemble.CalcImportance();
00381    fRuleEnsemble.CleanupRules();
00382    fRuleEnsemble.CleanupLinear();
00383    fRuleEnsemble.CalcVarImportance();
00384    Log() << kVERBOSE << "Filling rule statistics" << Endl;
00385    fRuleEnsemble.RuleResponseStats();
00386 }
00387 
00388 //_______________________________________________________________________
00389 Double_t TMVA::RuleFit::EvalEvent( const Event& e )
00390 {
00391    // evaluate single event
00392 
00393    return fRuleEnsemble.EvalEvent( e );
00394 }
00395 
00396 //_______________________________________________________________________
00397 void TMVA::RuleFit::SetTrainingEvents( const std::vector<Event *>& el )
00398 {
00399    // set the training events randomly
00400    if (fMethodRuleFit==0) Log() << kFATAL << "RuleFit::SetTrainingEvents - MethodRuleFit not initialized" << Endl;
00401    UInt_t neve = el.size();
00402    if (neve==0) Log() << kWARNING << "An empty sample of training events was given" << Endl;
00403 
00404    // copy vector
00405    fTrainingEvents.clear();
00406    fTrainingEventsRndm.clear();
00407    for (UInt_t i=0; i<neve; i++) {
00408       fTrainingEvents.push_back(static_cast< Event *>(el[i]));
00409       fTrainingEventsRndm.push_back(static_cast< Event *>(el[i]));
00410    }
00411 
00412    // Re-shuffle the vector, ie, recreate it in a random order
00413    std::random_shuffle( fTrainingEventsRndm.begin(), fTrainingEventsRndm.end() );
00414 
00415    // fraction events per tree
00416    fNTreeSample = static_cast<UInt_t>(neve*fMethodRuleFit->GetTreeEveFrac());
00417    Log() << kDEBUG << "Number of events per tree : " << fNTreeSample
00418            << " ( N(events) = " << neve << " )"
00419            << " randomly drawn without replacement" << Endl;
00420 }
00421 
00422 //_______________________________________________________________________
00423 void TMVA::RuleFit::GetRndmSampleEvents(std::vector< const Event * > & evevec, UInt_t nevents)
00424 {
00425    // draw a random subsample of the training events without replacement
00426    ReshuffleEvents();
00427    if ((nevents<fTrainingEventsRndm.size()) && (nevents>0)) {
00428       evevec.resize(nevents);
00429       for (UInt_t ie=0; ie<nevents; ie++) {
00430          evevec[ie] = fTrainingEventsRndm[ie];
00431       }
00432    } 
00433    else {
00434       Log() << kWARNING << "GetRndmSampleEvents() : requested sub sample size larger than total size (BUG!).";
00435    }
00436 }
00437 //_______________________________________________________________________
00438 void TMVA::RuleFit::NormVisHists(std::vector<TH2F *> & hlist)
00439 {
00440    // normalize rule importance hists
00441    //
00442    // if all weights are positive, the scale will be 1/maxweight
00443    // if minimum weight < 0, then the scale will be 1/max(maxweight,abs(minweight))
00444    //
00445    if (hlist.size()==0) return;
00446    //
00447    Double_t wmin=0;
00448    Double_t wmax=0;
00449    Double_t w,wm;
00450    Double_t awmin;
00451    Double_t scale;
00452    for (UInt_t i=0; i<hlist.size(); i++) {
00453       TH2F *hs = hlist[i];
00454       w  = hs->GetMaximum();
00455       wm = hs->GetMinimum();
00456       if (i==0) {
00457          wmin=wm;
00458          wmax=w;
00459       } 
00460       else {
00461          if (w>wmax)  wmax=w;
00462          if (wm<wmin) wmin=wm;
00463       }
00464    }
00465    awmin = TMath::Abs(wmin);
00466    Double_t usemin,usemax;
00467    if (awmin>wmax) {
00468       scale = 1.0/awmin;
00469       usemin = -1.0;
00470       usemax = scale*wmax;
00471    } 
00472    else {
00473       scale = 1.0/wmax;
00474       usemin = scale*wmin;
00475       usemax = 1.0;
00476    }
00477    
00478    //
00479    for (UInt_t i=0; i<hlist.size(); i++) {
00480       TH2F *hs = hlist[i];
00481       hs->Scale(scale);
00482       hs->SetMinimum(usemin);
00483       hs->SetMaximum(usemax);
00484    }
00485 }
00486 
00487 //_______________________________________________________________________
void TMVA::RuleFit::FillCut(TH2F* h2, const Rule *rule, Int_t vind)
{
   // Fill the 1D visualization histogram (stored as a TH2F with one Y bin)
   // with the cut range of the given rule for variable vind.
   // Each X bin covered by the cut is filled with the rule importance (or
   // coefficient*support, depending on fVisHistsUseImp), weighted at the
   // edge bins by the fraction of the bin inside the cut range.

   if (rule==0) return;
   if (h2==0) return;
   //
   Double_t rmin,  rmax;    // cut range of the rule for this variable
   Bool_t   dormin,dormax;  // whether a lower/upper cut boundary exists
   Bool_t ruleHasVar = rule->GetRuleCut()->GetCutRange(vind,rmin,rmax,dormin,dormax);
   if (!ruleHasVar) return; // rule does not cut on this variable
   //
   // global bin numbers of the first and last X bin (Y=1, Z=1)
   Int_t firstbin = h2->GetBin(1,1,1);
   if(firstbin<0) firstbin=0;
   Int_t lastbin = h2->GetBin(h2->GetNbinsX(),1,1);
   // bins containing the cut boundaries; full axis range if no cut on that side
   Int_t binmin=(dormin ? h2->FindBin(rmin,0.5):firstbin);
   Int_t binmax=(dormax ? h2->FindBin(rmax,0.5):lastbin);
   Int_t fbin;
   // NOTE(review): a single bin width is used throughout, which assumes
   // equidistant X binning — confirm against how the hists are booked
   Double_t xbinw = h2->GetBinWidth(firstbin);
   Double_t fbmin = h2->GetBinLowEdge(binmin-firstbin+1);       // low edge of first covered bin
   Double_t lbmax = h2->GetBinLowEdge(binmax-firstbin+1)+xbinw; // high edge of last covered bin
   // fraction of the first/last bin that lies inside the cut range
   Double_t fbfrac = (dormin ? ((fbmin+xbinw-rmin)/xbinw):1.0);
   Double_t lbfrac = (dormax ? ((rmax-lbmax+xbinw)/xbinw):1.0);
   Double_t f;    // per-bin fill fraction
   Double_t xc;   // bin center
   Double_t val;  // fill weight (importance or coefficient*support)

   for (Int_t bin = binmin; bin<binmax+1; bin++) {
      fbin = bin-firstbin+1; // convert global bin number to X-axis bin number
      if (bin==binmin) {
         f = fbfrac;
      }
      else if (bin==binmax) {
         f = lbfrac;
      }
      else {
         f = 1.0;
      }
      xc = h2->GetBinCenter(fbin);
      //
      if (fVisHistsUseImp) {
         val = rule->GetImportance();
      } 
      else {
         val = rule->GetCoefficient()*rule->GetSupport();
      }
      h2->Fill(xc,0.5,val*f);
   }
}
00537 
00538 //_______________________________________________________________________
00539 void TMVA::RuleFit::FillLin(TH2F* h2,Int_t vind)
00540 {
00541    // fill lin
00542    if (h2==0) return;
00543    if (!fRuleEnsemble.DoLinear()) return;
00544    //
00545    Int_t firstbin = 1;
00546    Int_t lastbin = h2->GetNbinsX();
00547    Double_t xc;
00548    Double_t val;
00549    if (fVisHistsUseImp) {
00550       val = fRuleEnsemble.GetLinImportance(vind);
00551    }
00552    else {
00553       val = fRuleEnsemble.GetLinCoefficients(vind);
00554    }
00555    for (Int_t bin = firstbin; bin<lastbin+1; bin++) {
00556       xc = h2->GetBinCenter(bin);
00557       h2->Fill(xc,0.5,val);
00558    }
00559 }
00560 
00561 //_______________________________________________________________________
00562 void TMVA::RuleFit::FillCorr(TH2F* h2,const Rule *rule,Int_t vx, Int_t vy)
00563 {
00564    // fill rule correlation between vx and vy, weighted with either the importance or the coefficient
00565    if (rule==0) return;
00566    if (h2==0) return;
00567    Double_t val;
00568    if (fVisHistsUseImp) {
00569       val = rule->GetImportance();
00570    }
00571    else {
00572       val = rule->GetCoefficient()*rule->GetSupport();
00573    }
00574    //
00575    Double_t rxmin,   rxmax,   rymin,   rymax;
00576    Bool_t   dorxmin, dorxmax, dorymin, dorymax;
00577    //
00578    // Get range in rule for X and Y
00579    //
00580    Bool_t ruleHasVarX = rule->GetRuleCut()->GetCutRange(vx,rxmin,rxmax,dorxmin,dorxmax);
00581    Bool_t ruleHasVarY = rule->GetRuleCut()->GetCutRange(vy,rymin,rymax,dorymin,dorymax);
00582    if (!(ruleHasVarX || ruleHasVarY)) return;
00583    // min max of varX and varY in hist
00584    Double_t vxmin = (dorxmin ? rxmin:h2->GetXaxis()->GetXmin());
00585    Double_t vxmax = (dorxmax ? rxmax:h2->GetXaxis()->GetXmax());
00586    Double_t vymin = (dorymin ? rymin:h2->GetYaxis()->GetXmin());
00587    Double_t vymax = (dorymax ? rymax:h2->GetYaxis()->GetXmax());
00588    // min max bin in X and Y
00589    Int_t binxmin  = h2->GetXaxis()->FindBin(vxmin);
00590    Int_t binxmax  = h2->GetXaxis()->FindBin(vxmax);
00591    Int_t binymin  = h2->GetYaxis()->FindBin(vymin);
00592    Int_t binymax  = h2->GetYaxis()->FindBin(vymax);
00593    // bin widths
00594    Double_t xbinw = h2->GetXaxis()->GetBinWidth(binxmin);
00595    Double_t ybinw = h2->GetYaxis()->GetBinWidth(binxmin);
00596    Double_t xbinmin = h2->GetXaxis()->GetBinLowEdge(binxmin);
00597    Double_t xbinmax = h2->GetXaxis()->GetBinLowEdge(binxmax)+xbinw;
00598    Double_t ybinmin = h2->GetYaxis()->GetBinLowEdge(binymin);
00599    Double_t ybinmax = h2->GetYaxis()->GetBinLowEdge(binymax)+ybinw;
00600    // fraction of edges
00601    Double_t fxbinmin = (dorxmin ? ((xbinmin+xbinw-vxmin)/xbinw):1.0);
00602    Double_t fxbinmax = (dorxmax ? ((vxmax-xbinmax+xbinw)/xbinw):1.0);
00603    Double_t fybinmin = (dorymin ? ((ybinmin+ybinw-vymin)/ybinw):1.0);
00604    Double_t fybinmax = (dorymax ? ((vymax-ybinmax+ybinw)/ybinw):1.0);
00605    //
00606    Double_t fx,fy;
00607    Double_t xc,yc;
00608    // fill histo
00609    for (Int_t binx = binxmin; binx<binxmax+1; binx++) {
00610       if (binx==binxmin) {
00611          fx = fxbinmin;
00612       } 
00613       else if (binx==binxmax) {
00614          fx = fxbinmax;
00615       } 
00616       else {
00617          fx = 1.0;
00618       }
00619       xc = h2->GetXaxis()->GetBinCenter(binx);
00620       for (Int_t biny = binymin; biny<binymax+1; biny++) {
00621          if (biny==binymin) {
00622             fy = fybinmin;
00623          } 
00624          else if (biny==binymax) {
00625             fy = fybinmax;
00626          } 
00627          else {
00628             fy = 1.0;
00629          }
00630          yc = h2->GetYaxis()->GetBinCenter(biny);
00631          h2->Fill(xc,yc,val*fx*fy);
00632       }
00633    }
00634 }
00635 
00636 //_______________________________________________________________________
00637 void TMVA::RuleFit::FillVisHistCut(const Rule* rule, std::vector<TH2F *> & hlist)
00638 {
00639    // help routine to MakeVisHists() - fills for all variables
00640    Int_t nhists = hlist.size();
00641    Int_t nvar   = fMethodBase->GetNvar();
00642    if (nhists!=nvar) Log() << kFATAL << "BUG TRAP: number of hists is not equal the number of variables!" << Endl;
00643    //
00644    std::vector<Int_t> vindex;
00645    TString hstr;
00646    // not a nice way to do a check...
00647    for (Int_t ih=0; ih<nhists; ih++) {
00648       hstr = hlist[ih]->GetTitle();
00649       for (Int_t iv=0; iv<nvar; iv++) {
00650          if (fMethodBase->GetInputTitle(iv) == hstr)
00651             vindex.push_back(iv);
00652       }
00653    }
00654    //
00655    for (Int_t iv=0; iv<nvar; iv++) {
00656       if (rule) {
00657          if (rule->ContainsVariable(vindex[iv])) {
00658             FillCut(hlist[iv],rule,vindex[iv]);
00659          }
00660       } 
00661       else {
00662          FillLin(hlist[iv],vindex[iv]);
00663       }
00664    }
00665 }
00666 //_______________________________________________________________________
00667 void TMVA::RuleFit::FillVisHistCorr(const Rule * rule, std::vector<TH2F *> & hlist)
00668 {
00669    // help routine to MakeVisHists() - fills for all correlation plots
00670    if (rule==0) return;
00671    Double_t ruleimp  = rule->GetImportance();
00672    if (!(ruleimp>0)) return;
00673    if (ruleimp<fRuleEnsemble.GetImportanceCut()) return;
00674    //
00675    Int_t nhists = hlist.size();
00676    Int_t nvar   = fMethodBase->GetNvar();
00677    Int_t ncorr  = (nvar*(nvar+1)/2)-nvar;
00678    if (nhists!=ncorr) Log() << kERROR << "BUG TRAP: number of corr hists is not correct! ncorr = "
00679                             << ncorr << " nvar = " << nvar << " nhists = " << nhists << Endl;
00680    //
00681    std::vector< std::pair<Int_t,Int_t> > vindex;
00682    TString hstr, var1, var2;
00683    Int_t iv1=0,iv2=0;
00684    // not a nice way to do a check...
00685    for (Int_t ih=0; ih<nhists; ih++) {
00686       hstr = hlist[ih]->GetName();
00687       if (GetCorrVars( hstr, var1, var2 )) {
00688          iv1 = fMethodBase->DataInfo().FindVarIndex( var1 );
00689          iv2 = fMethodBase->DataInfo().FindVarIndex( var2 );
00690          vindex.push_back( std::pair<Int_t,Int_t>(iv2,iv1) ); // pair X, Y
00691       } 
00692       else {
00693          Log() << kERROR << "BUG TRAP: should not be here - failed getting var1 and var2" << Endl;
00694       }
00695    }
00696    //
00697    for (Int_t ih=0; ih<nhists; ih++) {
00698       if ( (rule->ContainsVariable(vindex[ih].first)) ||
00699            (rule->ContainsVariable(vindex[ih].second)) ) {
00700          FillCorr(hlist[ih],rule,vindex[ih].first,vindex[ih].second);
00701       }
00702    }
00703 }
00704 //_______________________________________________________________________
00705 Bool_t TMVA::RuleFit::GetCorrVars(TString & title, TString & var1, TString & var2)
00706 {
00707    // get first and second variables from title
00708    var1="";
00709    var2="";
00710    if(!title.BeginsWith("scat_")) return kFALSE;
00711 
00712    TString titleCopy = title(5,title.Length());
00713    if(titleCopy.Index("_RF2D")>=0) titleCopy.Remove(titleCopy.Index("_RF2D"));
00714 
00715    Int_t splitPos = titleCopy.Index("_vs_");
00716    if(splitPos>=0) { // there is a _vs_ in the string
00717       var1 = titleCopy(0,splitPos);
00718       var2 = titleCopy(splitPos+4, titleCopy.Length());
00719       return kTRUE;
00720    } 
00721    else {
00722       var1 = titleCopy;
00723       return kFALSE;
00724    }
00725 }
00726 //_______________________________________________________________________
00727 void TMVA::RuleFit::MakeVisHists()
00728 {
00729    // this will create histograms visualizing the rule ensemble
00730 
00731    const TString directories[5] = { "InputVariables_Id",
00732                                     "InputVariables_Deco",
00733                                     "InputVariables_PCA",
00734                                     "InputVariables_Gauss",
00735                                     "InputVariables_Gauss_Deco" };
00736 
00737    const TString corrDirName = "CorrelationPlots";   
00738    
00739    TDirectory* rootDir   = Factory::RootBaseDir();
00740    TDirectory* varDir    = 0;
00741    TDirectory* corrDir   = 0;
00742 
00743    TDirectory* methodDir = fMethodBase->BaseDir();
00744    TString varDirName;
00745    //
00746    Bool_t done=(rootDir==0);
00747    Int_t type=0;
00748    if (done) {
00749       Log() << kWARNING << "No basedir - BUG??" << Endl;
00750       return;
00751    }
00752    while (!done) {
00753       varDir = (TDirectory*)rootDir->Get( directories[type] );
00754       type++;
00755       done = ((varDir!=0) || (type>4));
00756    }
00757    if (varDir==0) {
00758       Log() << kWARNING << "No input variable directory found - BUG?" << Endl;
00759       return;
00760    }
00761    corrDir = (TDirectory*)varDir->Get( corrDirName );
00762    if (corrDir==0) {
00763       Log() << kWARNING << "No correlation directory found" << Endl;
00764       Log() << kWARNING << "Check for other warnings related to correlation histograms" << Endl;
00765       return;
00766    }
00767    if (methodDir==0) {
00768       Log() << kWARNING << "No rulefit method directory found - BUG?" << Endl;
00769       return;
00770    }
00771 
00772    varDirName = varDir->GetName();
00773    varDir->cd();
00774    //
00775    // get correlation plot directory
00776    corrDir = (TDirectory *)varDir->Get(corrDirName);
00777    if (corrDir==0) {
00778       Log() << kWARNING << "No correlation directory found : " << corrDirName << Endl;
00779       return;
00780    }
00781 
00782    // how many plots are in the var directory?
00783    Int_t noPlots = ((varDir->GetListOfKeys())->GetEntries()) / 2;
00784    Log() << kDEBUG << "Got number of plots = " << noPlots << Endl;
00785  
00786    // loop over all objects in directory
00787    std::vector<TH2F *> h1Vector;
00788    std::vector<TH2F *> h2CorrVector;
00789    TIter next(varDir->GetListOfKeys());
00790    TKey *key;
00791    while ((key = (TKey*)next())) {
00792       // make sure, that we only look at histograms
00793       TClass *cl = gROOT->GetClass(key->GetClassName());
00794       if (!cl->InheritsFrom(TH1F::Class())) continue;
00795       TH1F *sig = (TH1F*)key->ReadObj();
00796       TString hname= sig->GetName();
00797       Log() << kDEBUG << "Got histogram : " << hname << Endl;
00798 
00799       // check for all signal histograms
00800       if (hname.Contains("__S")){ // found a new signal plot
00801          TString htitle = sig->GetTitle();
00802          htitle.ReplaceAll("signal","");
00803          TString newname = hname;
00804          newname.ReplaceAll("__Signal","__RF");
00805          newname.ReplaceAll("__S","__RF");
00806 
00807          methodDir->cd();
00808          TH2F *newhist = new TH2F(newname,htitle,sig->GetNbinsX(),sig->GetXaxis()->GetXmin(),sig->GetXaxis()->GetXmax(),
00809                                   1,sig->GetYaxis()->GetXmin(),sig->GetYaxis()->GetXmax());
00810          varDir->cd();
00811          h1Vector.push_back( newhist );
00812       }
00813    }
00814    //
00815    corrDir->cd();
00816    TString var1,var2;
00817    TIter nextCorr(corrDir->GetListOfKeys());
00818    while ((key = (TKey*)nextCorr())) {
00819       // make sure, that we only look at histograms
00820       TClass *cl = gROOT->GetClass(key->GetClassName());
00821       if (!cl->InheritsFrom(TH2F::Class())) continue;
00822       TH2F *sig = (TH2F*)key->ReadObj();
00823       TString hname= sig->GetName();
00824 
00825       // check for all signal histograms
00826       if ((hname.Contains("scat_")) && (hname.Contains("_Signal"))) {
00827          Log() << kDEBUG << "Got histogram (2D) : " << hname << Endl;
00828          TString htitle = sig->GetTitle();
00829          htitle.ReplaceAll("(Signal)","");
00830          TString newname = hname;
00831          newname.ReplaceAll("_Signal","_RF2D");
00832 
00833          methodDir->cd();
00834          const Int_t rebin=2;
00835          TH2F *newhist = new TH2F(newname,htitle,
00836                                   sig->GetNbinsX()/rebin,sig->GetXaxis()->GetXmin(),sig->GetXaxis()->GetXmax(),
00837                                   sig->GetNbinsY()/rebin,sig->GetYaxis()->GetXmin(),sig->GetYaxis()->GetXmax());
00838          if (GetCorrVars( newname, var1, var2 )) {
00839             Int_t iv1 = fMethodBase->DataInfo().FindVarIndex(var1);
00840             Int_t iv2 = fMethodBase->DataInfo().FindVarIndex(var2);
00841             if (iv1<0) {
00842                sig->GetYaxis()->SetTitle(var1);
00843             } 
00844             else {
00845                sig->GetYaxis()->SetTitle(fMethodBase->GetInputTitle(iv1));
00846             }
00847             if (iv2<0) {
00848                sig->GetXaxis()->SetTitle(var2);
00849             } 
00850             else {
00851                sig->GetXaxis()->SetTitle(fMethodBase->GetInputTitle(iv2));
00852             }
00853          }
00854          corrDir->cd();
00855          h2CorrVector.push_back( newhist );
00856       }
00857    }
00858 
00859 
00860    varDir->cd();
00861    // fill rules
00862    UInt_t nrules = fRuleEnsemble.GetNRules();
00863    const Rule *rule;
00864    for (UInt_t i=0; i<nrules; i++) {
00865       rule = fRuleEnsemble.GetRulesConst(i);
00866       FillVisHistCut(rule, h1Vector);
00867    }
00868    // fill linear terms and normalise hists
00869    FillVisHistCut(0, h1Vector);
00870    NormVisHists(h1Vector);
00871  
00872    //
00873    corrDir->cd();
00874    // fill rules
00875    for (UInt_t i=0; i<nrules; i++) {
00876       rule = fRuleEnsemble.GetRulesConst(i);
00877       FillVisHistCorr(rule, h2CorrVector);
00878    }
00879    NormVisHists(h2CorrVector);
00880 
00881    // write histograms to file   
00882    methodDir->cd();
00883    for (UInt_t i=0; i<h1Vector.size();     i++) h1Vector[i]->Write();
00884    for (UInt_t i=0; i<h2CorrVector.size(); i++) h2CorrVector[i]->Write();
00885 }
00886 
00887 //_______________________________________________________________________
00888 void TMVA::RuleFit::MakeDebugHists()
00889 {
00890    // this will create a histograms intended rather for debugging or for the curious user
00891 
00892    TDirectory* methodDir = fMethodBase->BaseDir();
00893    if (methodDir==0) {
00894       Log() << kWARNING << "<MakeDebugHists> No rulefit method directory found - bug?" << Endl;
00895       return;
00896    }
00897    //
00898    methodDir->cd();
00899    std::vector<Double_t> distances;
00900    std::vector<Double_t> fncuts;
00901    std::vector<Double_t> fnvars;
00902    const Rule *ruleA;
00903    const Rule *ruleB;
00904    Double_t dABmin=1000000.0;
00905    Double_t dABmax=-1.0;
00906    UInt_t nrules = fRuleEnsemble.GetNRules();
00907    for (UInt_t i=0; i<nrules; i++) {
00908       ruleA = fRuleEnsemble.GetRulesConst(i);
00909       for (UInt_t j=i+1; j<nrules; j++) {
00910          ruleB = fRuleEnsemble.GetRulesConst(j);
00911          Double_t dAB = ruleA->RuleDist( *ruleB, kTRUE );
00912          if (dAB>-0.5) {
00913             UInt_t nc = ruleA->GetNcuts();
00914             UInt_t nv = ruleA->GetNumVarsUsed();
00915             distances.push_back(dAB);
00916             fncuts.push_back(static_cast<Double_t>(nc));
00917             fnvars.push_back(static_cast<Double_t>(nv));
00918             if (dAB<dABmin) dABmin=dAB;
00919             if (dAB>dABmax) dABmax=dAB;
00920          }
00921       }
00922    }
00923    //
00924    TH1F *histDist = new TH1F("RuleDist","Rule distances",100,dABmin,dABmax);
00925    TTree *distNtuple = new TTree("RuleDistNtuple","RuleDist ntuple");
00926    Double_t ntDist;
00927    Double_t ntNcuts;
00928    Double_t ntNvars;
00929    distNtuple->Branch("dist", &ntDist,  "dist/D");
00930    distNtuple->Branch("ncuts",&ntNcuts, "ncuts/D");
00931    distNtuple->Branch("nvars",&ntNvars, "nvars/D");
00932    //
00933    for (UInt_t i=0; i<distances.size(); i++) {
00934       histDist->Fill(distances[i]);
00935       ntDist  = distances[i];
00936       ntNcuts = fncuts[i];
00937       ntNvars = fnvars[i];
00938       distNtuple->Fill();
00939    }
00940    distNtuple->Write();
00941 }

Generated on Tue Jul 5 15:25:35 2011 for ROOT_528-00b_version by  doxygen 1.5.1