DecisionTree.cxx

00001 // @(#)root/tmva $Id: DecisionTree.cxx 38085 2011-02-16 10:29:08Z evt $
00002 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : TMVA::DecisionTree                                                    *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Implementation of a Decision Tree                                         *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00015  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00016  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
00017  *                                                                                *
00018  * Copyright (c) 2005:                                                            *
00019  *      CERN, Switzerland                                                         *
00020  *      U. of Victoria, Canada                                                    *
00021  *      MPI-K Heidelberg, Germany                                                 *
00022  *                                                                                *
00023  * Redistribution and use in source and binary forms, with or without             *
00024  * modification, are permitted according to the terms listed in LICENSE           *
00025  * (http://tmva.sourceforge.net/license.txt)                                      *
00026  *                                                                                *
00027  **********************************************************************************/
00028 
00029 //_______________________________________________________________________
00030 //
00031 // Implementation of a Decision Tree
00032 //
00033 // In a decision tree successive decision nodes are used to categorize the
00034 // events out of the sample as either signal or background. Each node
00035 // uses only a single discriminating variable to decide if the event is
00036 // signal-like ("goes right") or background-like ("goes left"). This
00037 // forms a tree-like structure with "baskets" at the end (leaf nodes),
00038 // and an event is classified as either signal or background according to
00039 // whether the basket where it ends up has been classified signal or
00040 // background during the training. Training of a decision tree is the
00041 // process of defining the "cut criteria" for each node. The training
00042 // starts with the root node. Here one takes the full training event
00043 // sample and selects the variable and corresponding cut value that gives
00044 // the best separation between signal and background at this stage. Using
00045 // this cut criterion, the sample is then divided into two subsamples, a
00046 // signal-like (right) and a background-like (left) sample. Two new nodes
00047 // are then created for each of the two sub-samples and they are
00048 // constructed using the same mechanism as described for the root
00049 // node. The division is stopped once a certain node has reached either a
00050 // minimum number of events, or a minimum or maximum signal purity. These
00051 // leaf nodes are then called "signal" or "background" depending on whether
00052 // they contain more signal or background events from the training sample.
00053 //_______________________________________________________________________
00054 
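// Editorial sketch (not part of the original TMVA sources): the recursive
// training scheme described above, condensed into pseudocode. All names
// below (Sample, BestCut, Split, ...) are hypothetical placeholders, not
// TMVA API:
//
//    void Train( Node* node, Sample events )
//    {
//       if (TooSmallOrTooDeep(node, events)) { MakeLeaf(node, events); return; }
//       Cut cut = BestCut(events);            // scan variables and cut values
//       Sample left, right;
//       Split(events, cut, left, right);      // background-like / signal-like
//       Train( node->CreateLeft(),  left  );
//       Train( node->CreateRight(), right );
//    }
//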
00055 #include <iostream>
00056 #include <algorithm>
00057 #include <vector>
00058 #include <limits>
00059 #include <fstream>
00061 #include <cassert>
00062 
00063 #include "TRandom3.h"
00064 #include "TMath.h"
00065 #include "TMatrix.h"
00066 
00067 #include "TMVA/MsgLogger.h"
00068 #include "TMVA/DecisionTree.h"
00069 #include "TMVA/DecisionTreeNode.h"
00070 #include "TMVA/BinarySearchTree.h"
00071 
00072 #include "TMVA/Tools.h"
00073 
00074 #include "TMVA/GiniIndex.h"
00075 #include "TMVA/CrossEntropy.h"
00076 #include "TMVA/MisClassificationError.h"
00077 #include "TMVA/SdivSqrtSplusB.h"
00078 #include "TMVA/Event.h"
00079 #include "TMVA/BDTEventWrapper.h"
00080 #include "TMVA/IPruneTool.h"
00081 #include "TMVA/CostComplexityPruneTool.h"
00082 #include "TMVA/ExpectedErrorPruneTool.h"
00083 
00084 const Int_t TMVA::DecisionTree::fgRandomSeed = 0; // set nonzero for debugging and zero for random seeds
00085 
00086 using std::vector;
00087 
00088 ClassImp(TMVA::DecisionTree)
00089 
00090 //_______________________________________________________________________
00091 TMVA::DecisionTree::DecisionTree():
00092    BinaryTree(),
00093    fNvars          (0),
00094    fNCuts          (-1),
00095    fUseFisherCuts  (kFALSE),
00096    fMinLinCorrForFisher (1),
00097    fUseExclusiveVars (kTRUE),
00098    fSepType        (NULL),
00099    fRegType        (NULL),
00100    fMinSize        (0),
00101    fMinSepGain (0),
00102    fUseSearchTree(kFALSE),
00103    fPruneStrength(0),
00104    fPruneMethod    (kNoPruning),
00105    fNodePurityLimit(0.5),
00106    fRandomisedTree (kFALSE),
00107    fUseNvars       (0),
00108    fUsePoissonNvars(kFALSE),
00109    fMyTrandom (NULL), 
00110    fNNodesMax      (999999),
00111    fMaxDepth       (999999),
00112    fClass          (0),
00113    fTreeID         (0),
00114    fAnalysisType   (Types::kClassification)
00115 {
00116    // default constructor using the GiniIndex as separation criterion,
00117    // no restrictions on the minimum number of events in a leaf node or on the
00118    // separation gain in the node splitting
00119 }
00120 
00121 //_______________________________________________________________________
00122 TMVA::DecisionTree::DecisionTree( TMVA::SeparationBase *sepType, Int_t minSize, Int_t nCuts, UInt_t cls,
00123                                   Bool_t randomisedTree, Int_t useNvars, Bool_t usePoissonNvars, UInt_t nNodesMax,
00124                                   UInt_t nMaxDepth, Int_t iSeed, Float_t purityLimit, Int_t treeID):
00125    BinaryTree(),
00126    fNvars          (0),
00127    fNCuts          (nCuts),
00128    fUseFisherCuts  (kFALSE),
00129    fMinLinCorrForFisher (1),
00130    fUseExclusiveVars (kTRUE),
00131    fSepType        (sepType),
00132    fRegType        (NULL),
00133    fMinSize        (minSize),
00134    fMinSepGain     (0),
00135    fUseSearchTree  (kFALSE),
00136    fPruneStrength  (0),
00137    fPruneMethod    (kNoPruning),
00138    fNodePurityLimit(purityLimit),
00139    fRandomisedTree (randomisedTree),
00140    fUseNvars       (useNvars),
00141    fUsePoissonNvars(usePoissonNvars),
00142    fMyTrandom      (new TRandom3(iSeed)),
00143    fNNodesMax      (nNodesMax),
00144    fMaxDepth       (nMaxDepth),
00145    fClass          (cls),
00146    fTreeID         (treeID)
00147 {
00148    // constructor specifying the separation type, the minimum number of
00149    // events in a node that is still subjected to further splitting, and the
00150    // number of bins in the grid used in applying the cut for the node
00151    // splitting.
00152 
00153    if (sepType == NULL) { // it is interpreted as a regression tree, where
00154                           // currently the separation type (simple least squares)
00155                           // cannot be chosen freely
00156       fAnalysisType = Types::kRegression;
00157       fRegType = new RegressionVariance();
00158       if ( nCuts <=0 ) {
00159          fNCuts = 200;
00160          Log() << kWARNING << " You had chosen the training mode using optimal cuts, not\n"
00161                << " based on a grid, by setting the option NCuts < 0. As this is not\n"
00162                << " available yet for regression trees, I set NCuts to " << fNCuts << " and use the grid"
00163                << Endl;
00164       }
00165    }else{
00166       fAnalysisType = Types::kClassification;
00167    }
00168 }
00169 
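// Editorial usage sketch (not part of the original file): instantiating a
// classification and a regression tree with this constructor. The argument
// values are illustrative only, and the short calls assume the default
// arguments declared for the remaining parameters in DecisionTree.h:
//
//    // classification: Gini index, >=20 events per node, 20-bin cut grid,
//    // signal class 0
//    TMVA::DecisionTree classTree( new TMVA::GiniIndex(), 20, 20, 0 );
//
//    // regression: a NULL separation type selects least-squares regression
//    TMVA::DecisionTree regTree( NULL, 20, 20, 0 );
//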
00170 //_______________________________________________________________________
00171 TMVA::DecisionTree::DecisionTree( const DecisionTree &d ):
00172    BinaryTree(),
00173    fNvars      (d.fNvars),
00174    fNCuts      (d.fNCuts),
00175    fUseFisherCuts  (d.fUseFisherCuts),
00176    fMinLinCorrForFisher (d.fMinLinCorrForFisher),
00177    fUseExclusiveVars (d.fUseExclusiveVars),
00178    fSepType    (d.fSepType),
00179    fRegType    (d.fRegType),
00180    fMinSize    (d.fMinSize),
00181    fMinSepGain (d.fMinSepGain),
00182    fUseSearchTree  (d.fUseSearchTree),
00183    fPruneStrength  (d.fPruneStrength),
00184    fPruneMethod    (d.fPruneMethod),
00185    fNodePurityLimit(d.fNodePurityLimit),
00186    fRandomisedTree (d.fRandomisedTree),
00187    fUseNvars       (d.fUseNvars),
00188    fUsePoissonNvars(d.fUsePoissonNvars),
00189    fMyTrandom      (new TRandom3(fgRandomSeed)),  // well, that means it's not an identical copy. But I only ever intend to really copy trees that are "outgrown" already. 
00190    fNNodesMax  (d.fNNodesMax),
00191    fMaxDepth   (d.fMaxDepth),
00192    fClass      (d.fClass),
00193    fTreeID     (d.fTreeID),
00194    fAnalysisType(d.fAnalysisType)
00195 {
00196    // copy constructor that creates a true copy, i.e. a completely independent tree;
00197    // the node copy will recursively copy all the nodes
00198    this->SetRoot( new TMVA::DecisionTreeNode ( *((DecisionTreeNode*)(d.GetRoot())) ) );
00199    this->SetParentTreeInNodes();
00200    fNNodes = d.fNNodes;
00201 }
00202 
00203 
00204 //_______________________________________________________________________
00205 TMVA::DecisionTree::~DecisionTree()
00206 {
00207    // destructor
00208 
00209    // destruction of the tree nodes done in the "base class" BinaryTree
00210 
00211    if (fMyTrandom) delete fMyTrandom;
00212 }
00213 
00214 //_______________________________________________________________________
00215 void TMVA::DecisionTree::SetParentTreeInNodes( Node *n )
00216 {
00217    // descend the tree and set the parent-tree pointer in all nodes, filling in
00218    // the maximum depth reached in the tree at the same time.
00219 
00220    if (n == NULL) { //default, start at the tree top, then descend recursively
00221       n = this->GetRoot();
00222       if (n == NULL) {
00223          Log() << kFATAL << "SetParentTreeNodes: started with undefined ROOT node" <<Endl;
00224          return ;
00225       }
00226    }
00227 
00228    if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) != NULL) ) {
00229       Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
00230       return;
00231    }  else if ((this->GetLeftDaughter(n) != NULL) && (this->GetRightDaughter(n) == NULL) ) {
00232       Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
00233       return;
00234    }
00235    else {
00236       if (this->GetLeftDaughter(n) != NULL) {
00237          this->SetParentTreeInNodes( this->GetLeftDaughter(n) );
00238       }
00239       if (this->GetRightDaughter(n) != NULL) {
00240          this->SetParentTreeInNodes( this->GetRightDaughter(n) );
00241       }
00242    }
00243    n->SetParentTree(this);
00244    if (n->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(n->GetDepth());
00245    return;
00246 }
00247 
00248 //_______________________________________________________________________
00249 TMVA::DecisionTree* TMVA::DecisionTree::CreateFromXML(void* node, UInt_t tmva_Version_Code ) {
00250    // re-create a new tree (decision tree or search tree) from XML
00251    std::string type("");
00252    gTools().ReadAttr(node,"type", type);
00253    DecisionTree* dt = new DecisionTree();
00254 
00255    dt->ReadXML( node, tmva_Version_Code );
00256    return dt;
00257 }
00258 
00259 
00260 //_______________________________________________________________________
00261 UInt_t TMVA::DecisionTree::BuildTree( const vector<TMVA::Event*> & eventSample,
00262                                       TMVA::DecisionTreeNode *node)
00263 {
00264    // building the decision tree by recursively calling the splitting of
00265    // one (root-) node into two daughter nodes (returns the number of nodes)
00266 
00267    Bool_t IsRootNode=kFALSE;
00268    if (node==NULL) {
00269       IsRootNode = kTRUE;
00270       //start with the root node
00271       node = new TMVA::DecisionTreeNode();
00272       fNNodes = 1;
00273       this->SetRoot(node);
00274       // have to use "s" for start as "r" for "root" would be the same as "r" for "right"
00275       this->GetRoot()->SetPos('s');
00276       this->GetRoot()->SetDepth(0);
00277       this->GetRoot()->SetParentTree(this);
00278    }
00279 
00280    UInt_t nevents = eventSample.size();
00281 
00282    if (nevents > 0 ) {
00283       fNvars = eventSample[0]->GetNVariables();
00284       fVariableImportance.resize(fNvars);
00285    }
00286    else Log() << kFATAL << "<BuildTree> event sample size == 0 " << Endl;
00287 
00288    Double_t s=0, b=0;
00289    Double_t suw=0, buw=0;
00290    Double_t target=0, target2=0;
00291    Float_t *xmin = new Float_t[fNvars];
00292    Float_t *xmax = new Float_t[fNvars];
00293    for (UInt_t ivar=0; ivar<fNvars; ivar++) {
00294       xmin[ivar]=xmax[ivar]=0;
00295    }
00296    for (UInt_t iev=0; iev<eventSample.size(); iev++) {
00297       const TMVA::Event* evt = eventSample[iev];
00298       const Double_t weight = evt->GetWeight();
00299       if (evt->GetClass() == fClass) {
00300          s += weight;
00301          suw += 1;
00302       }
00303       else {
00304          b += weight;
00305          buw += 1;
00306       }
00307       if ( DoRegression() ) {
00308          const Double_t tgt = evt->GetTarget(0);
00309          target +=weight*tgt;
00310          target2+=weight*tgt*tgt;
00311       }
00312 
00313       for (UInt_t ivar=0; ivar<fNvars; ivar++) {
00314          const Double_t val = evt->GetValue(ivar);
00315          if (iev==0) xmin[ivar]=xmax[ivar]=val;
00316          if (val < xmin[ivar]) xmin[ivar]=val;
00317          if (val > xmax[ivar]) xmax[ivar]=val;
00318       }
00319    }
00320 
00321    if (s+b < 0) {
00322       Log() << kWARNING << " One of the Decision Tree nodes has a negative total number of signal or background events. "
00323             << "(Nsig="<<s<<" Nbkg="<<b<<"). Probably you use a Monte Carlo with negative weights. That should in principle "
00324             << "be fine as long as on average you end up with something positive. For this you have to make sure that the "
00325             << "minimum number of (unweighted) events demanded for a tree node (currently you use: nEventsMin="<<fMinSize
00326             << ", you can set this via the BDT option string when booking the classifier) is large enough to allow for "
00327             << "reasonable averaging!!!" << Endl
00328             << " If this does not help, maybe you want to try the option NoNegWeightsInTraining, which ignores events "
00329             << "with negative weight in the training." << Endl;
00330       double nBkg=0.;
00331       for (UInt_t i=0; i<eventSample.size(); i++) {
00332          if (eventSample[i]->GetClass() != fClass) {
00333             nBkg += eventSample[i]->GetWeight();
00334             Log() << kINFO << "Event "<< i<< " has (original) weight: " <<  eventSample[i]->GetWeight()/eventSample[i]->GetBoostWeight() 
00335                   << " boostWeight: " << eventSample[i]->GetBoostWeight() << Endl;
00336          }
00337       }
00338       Log() << kINFO << " that gives in total: " << nBkg<<Endl;
00339    }
00340 
00341    node->SetNSigEvents(s);
00342    node->SetNBkgEvents(b);
00343    node->SetNSigEvents_unweighted(suw);
00344    node->SetNBkgEvents_unweighted(buw);
00345    node->SetPurity();
00346    if (node == this->GetRoot()) {
00347       node->SetNEvents(s+b);
00348       node->SetNEvents_unweighted(suw+buw);
00349    }
00350    for (UInt_t ivar=0; ivar<fNvars; ivar++) {
00351       node->SetSampleMin(ivar,xmin[ivar]);
00352       node->SetSampleMax(ivar,xmax[ivar]);
00353    }
00354    delete[] xmin;
00355    delete[] xmax;
00356 
00357    // I now demand the minimum number of events for both daughter nodes. Hence if the number
00358    // of events in the parent node is not at least two times as big, I don't even need to try
00359    // splitting
00360 
00361    if (eventSample.size() >= 2*fMinSize && fNNodes < fNNodesMax && node->GetDepth() < fMaxDepth 
00362        && ( ( s!=0 && b !=0 && !DoRegression()) || ( (s+b)!=0 && DoRegression()) ) ) {
00363       Double_t separationGain;
00364       if (fNCuts > 0){
00365          separationGain = this->TrainNodeFast(eventSample, node);
00366       } else {
00367          separationGain = this->TrainNodeFull(eventSample, node);
00368       }
00369       if (separationGain < std::numeric_limits<double>::epsilon()) { // we could not gain anything, e.g. all events are in one bin,
00371          // no cut can actually do anything to improve the node
00372          // hence, naturally, the current node is a leaf node
00373          if (DoRegression()) {
00374             node->SetSeparationIndex(fRegType->GetSeparationIndex(s+b,target,target2));
00375             node->SetResponse(target/(s+b));
00376             node->SetRMS(TMath::Sqrt(target2/(s+b) - target/(s+b)*target/(s+b)));
00377          }
00378          else {
00379             node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
00380          }
00381          if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
00382          else node->SetNodeType(-1);
00383          if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
00384 
00385       } else {
00386 
00387          vector<TMVA::Event*> leftSample; leftSample.reserve(nevents);
00388          vector<TMVA::Event*> rightSample; rightSample.reserve(nevents);
00389 
00390          Double_t nRight=0, nLeft=0;
00391 
00392          for (UInt_t ie=0; ie< nevents ; ie++) {
00393             if (node->GoesRight(*eventSample[ie])) {
00394                rightSample.push_back(eventSample[ie]);
00395                nRight += eventSample[ie]->GetWeight();
00396             }
00397             else {
00398                leftSample.push_back(eventSample[ie]);
00399                nLeft += eventSample[ie]->GetWeight();
00400             }
00401          }
00402 
00403          // sanity check
00404          if (leftSample.size() == 0 || rightSample.size() == 0) {
00405             Log() << kFATAL << "<TrainNode> all events went to the same branch" << Endl
00406                   << "---                       Hence new node == old node ... check" << Endl
00407                   << "---                         left:" << leftSample.size()
00408                   << " right:" << rightSample.size() << Endl
00409                   << "--- this should never happen, please write a bug report to Helge.Voss@cern.ch"
00410                   << Endl;
00411          }
00412 
00413          // continue building daughter nodes for the left and the right eventsample
00414          TMVA::DecisionTreeNode *rightNode = new TMVA::DecisionTreeNode(node,'r');
00415          fNNodes++;
00416          rightNode->SetNEvents(nRight);
00417          rightNode->SetNEvents_unweighted(rightSample.size());
00418 
00419          TMVA::DecisionTreeNode *leftNode = new TMVA::DecisionTreeNode(node,'l');
00420 
00421          fNNodes++;
00422          leftNode->SetNEvents(nLeft);
00423          leftNode->SetNEvents_unweighted(leftSample.size());
00424 
00425          node->SetNodeType(0);
00426          node->SetLeft(leftNode);
00427          node->SetRight(rightNode);
00428 
00429          this->BuildTree(rightSample, rightNode);
00430          this->BuildTree(leftSample,  leftNode );
00431       }
00432    }
00433    else{ // it is a leaf node
00434       if (DoRegression()) {
00435          node->SetSeparationIndex(fRegType->GetSeparationIndex(s+b,target,target2));
00436          node->SetResponse(target/(s+b));
00437          node->SetRMS(TMath::Sqrt(target2/(s+b) - target/(s+b)*target/(s+b)));
00438       }
00439       else {
00440          node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
00441       }
00442 
00443       if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
00444       else node->SetNodeType(-1);
00445 
00446       if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
00447    }
00448   
00449    //   if (IsRootNode) this->CleanTree();
00450    return fNNodes;
00451 }
00452 
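// Editorial usage sketch (not part of the original file): growing a tree
// from a vector of training events. The event vector is a hypothetical
// placeholder (in TMVA it is normally prepared by MethodBDT), and the call
// relies on BuildTree's default node argument to start at the root:
//
//    std::vector<TMVA::Event*> trainingEvents; // filled by the caller
//    TMVA::DecisionTree tree( new TMVA::GiniIndex(), 20, 20, 0 );
//    UInt_t nNodes = tree.BuildTree( trainingEvents );
//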
00453 //_______________________________________________________________________
00454 void TMVA::DecisionTree::FillTree( vector<TMVA::Event*> & eventSample )
00455   
00456 {
00457    // fill the existing decision tree structure by feeding the events in
00458    // at the top node and seeing where they happen to end up
00459    for (UInt_t i=0; i<eventSample.size(); i++) {
00460       this->FillEvent(*(eventSample[i]),NULL);
00461    }
00462 }
00463 
00464 //_______________________________________________________________________
00465 void TMVA::DecisionTree::FillEvent( TMVA::Event & event,  
00466                                     TMVA::DecisionTreeNode *node )
00467 {
00468    // fill the existing decision tree structure by feeding a single event
00469    // in at the top node and seeing where it happens to end up
00470   
00471    if (node == NULL) { // that's the start, take the Root node
00472       node = this->GetRoot();
00473    }
00474   
00475    node->IncrementNEvents( event.GetWeight() );
00476    node->IncrementNEvents_unweighted( );
00477   
00478    if (event.GetClass() == fClass) {
00479       node->IncrementNSigEvents( event.GetWeight() );
00480       node->IncrementNSigEvents_unweighted( );
00481    } 
00482    else {
00483       node->IncrementNBkgEvents( event.GetWeight() );
00484       node->IncrementNBkgEvents_unweighted( );
00485    }
00486    node->SetSeparationIndex(fSepType->GetSeparationIndex(node->GetNSigEvents(),
00487                                                          node->GetNBkgEvents()));
00488   
00489    if (node->GetNodeType() == 0) { //intermediate node --> go down
00490       if (node->GoesRight(event))
00491          this->FillEvent(event,dynamic_cast<TMVA::DecisionTreeNode*>(node->GetRight())) ;
00492       else
00493          this->FillEvent(event,dynamic_cast<TMVA::DecisionTreeNode*>(node->GetLeft())) ;
00494    }
00495   
00496   
00497 }
00498 
00499 //_______________________________________________________________________
00500 void TMVA::DecisionTree::ClearTree()
00501 {
00502    // clear the tree nodes (their S/N, Nevents etc), just keep the structure of the tree
00503   
00504    if (this->GetRoot()!=NULL) this->GetRoot()->ClearNodeAndAllDaughters();
00505   
00506 }
00507 
00508 //_______________________________________________________________________
00509 UInt_t TMVA::DecisionTree::CleanTree( DecisionTreeNode *node )
00510 {
00511    // remove those last splits that result in two leaf nodes that
00512    // are both of the same type (i.e. both signal or both background);
00513    // this of course is only a reasonable thing to do when you use
00514    // "YesOrNo" leafs, while it might lose something if you use the
00515    // purity information in the nodes.
00516    // --> hence I don't call it automatically in the tree building
00517 
00518    if (node==NULL) {
00519       node = this->GetRoot();
00520    }
00521 
00522    DecisionTreeNode *l = node->GetLeft();
00523    DecisionTreeNode *r = node->GetRight();
00524 
00525    if (node->GetNodeType() == 0) {
00526       this->CleanTree(l);
00527       this->CleanTree(r);
00528       if (l->GetNodeType() * r->GetNodeType() > 0) {
00529 
00530          this->PruneNode(node);
00531       }
00532    }
00533    // update the number of nodes after the cleaning
00534    return this->CountNodes();
00535    
00536 }
00537 
00538 //_______________________________________________________________________
00539 Double_t TMVA::DecisionTree::PruneTree( vector<Event*>* validationSample )
00540 {
00541    // prune (get rid of internal nodes) the decision tree to avoid overtraining;
00542    // several different pruning methods can be applied, as selected by the
00543    // variable "fPruneMethod".
00544   
00545    //   std::ofstream logfile("dt_pruning.log");
00546   
00547    IPruneTool* tool(NULL);
00548    PruningInfo* info(NULL);
00549 
00550    if( fPruneMethod == kNoPruning ) return 0.0;
00551 
00552    if      (fPruneMethod == kExpectedErrorPruning) 
00553       //      tool = new ExpectedErrorPruneTool(logfile);
00554       tool = new ExpectedErrorPruneTool();
00555    else if (fPruneMethod == kCostComplexityPruning) 
00556       {
00557          tool = new CostComplexityPruneTool();
00558       }
00559    else {
00560       Log() << kFATAL << "Selected pruning method not yet implemented "
00561             << Endl;
00562    }
00563    if(!tool) return 0.0;
00564 
00565    tool->SetPruneStrength(GetPruneStrength());
00566    if(tool->IsAutomatic()) {
00567       if(validationSample == NULL) 
00568          Log() << kFATAL << "Cannot automate the pruning algorithm without an "
00569                << "independent validation sample!" << Endl;
00570       if(validationSample->size() == 0) 
00571          Log() << kFATAL << "Cannot automate the pruning algorithm with "
00572                << "independent validation sample of ZERO events!" << Endl;
00573    }
00574 
00575    info = tool->CalculatePruningInfo(this,validationSample);
00576    if(!info) {
00577       delete tool;
00578       Log() << kFATAL << "Error pruning tree! Check prune.log for more information." 
00579             << Endl;
00580    }
00581    Double_t pruneStrength = info->PruneStrength;
00582 
00583    //   Log() << kDEBUG << "Optimal prune strength (alpha): " << pruneStrength
00584    //           << " has quality index " << info->QualityIndex << Endl;
00585    
00586 
00587    for (UInt_t i = 0; i < info->PruneSequence.size(); ++i) {
00588       
00589       PruneNode(info->PruneSequence[i]);
00590    }
00591    // update the number of nodes after the pruning
00592    this->CountNodes();
00593 
00594    delete tool;
00595    delete info;
00596 
00597    return pruneStrength;
00598 }
00599 
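// Editorial usage sketch (not part of the original file): given a grown
// DecisionTree "tree" (see the sketch after BuildTree above), selecting a
// pruning method and pruning against an independent validation sample.
// SetPruneMethod/SetPruneStrength are taken from the class interface in
// DecisionTree.h; the sample is a hypothetical placeholder:
//
//    tree.SetPruneMethod( TMVA::DecisionTree::kCostComplexityPruning );
//    tree.SetPruneStrength( -1. ); // negative: let the tool optimise it
//    std::vector<TMVA::Event*> validationEvents; // filled by the caller
//    Double_t alpha = tree.PruneTree( &validationEvents );
//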
00600 
00601 //_______________________________________________________________________
00602 void TMVA::DecisionTree::ApplyValidationSample( const EventList* validationSample ) const
00603 {
00604    // run the validation sample through the (pruned) tree and fill in, for each
00605    // node, the variables NSValidation and NBValidation (i.e. how many of the
00606    // signal and background events from the validation sample end up there).
00607    // This is then later used when asking for the "tree quality".
00608    GetRoot()->ResetValidationData();
00609    for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
00610       CheckEventWithPrunedTree(*(*validationSample)[ievt]);
00611    }
00612 }
00613 
00614 //_______________________________________________________________________
00615 Double_t TMVA::DecisionTree::TestPrunedTreeQuality( const DecisionTreeNode* n, Int_t mode ) const
00616 {
00617    // return the misclassification rate of a pruned tree;
00618    // a "pruned tree" may have the flag "IsTerminal" set at any node, at which
00619    // point this tree quality testing stops descending; hence it tests
00620    // the pruned tree (while the full tree is still in place for normal/later use)
00621    
00622    if (n == NULL) { // default, start at the tree top, then descend recursively
00623       n = this->GetRoot();
00624       if (n == NULL) {
00625          Log() << kFATAL << "TestPrunedTreeQuality: started with undefined ROOT node" <<Endl;
00626          return 0;
00627       }
00628    } 
00629 
00630    if( n->GetLeft() != NULL && n->GetRight() != NULL && !n->IsTerminal() ) {
00631       return (TestPrunedTreeQuality( n->GetLeft(), mode ) +
00632               TestPrunedTreeQuality( n->GetRight(), mode ));
00633    }
00634    else { // terminal leaf (in a pruned subtree of T_max at least)
00635       if (DoRegression()) {
00636          Double_t sumw = n->GetNSValidation() + n->GetNBValidation();
00637          return n->GetSumTarget2() - 2*n->GetSumTarget()*n->GetResponse() + sumw*n->GetResponse()*n->GetResponse();
00638       } 
00639       else {
00640          if (mode == 0) {
00641             if (n->GetPurity() > this->GetNodePurityLimit()) // this is a signal leaf, according to the training
00642                return n->GetNBValidation();
00643             else
00644                return n->GetNSValidation();
00645          }
00646          else if ( mode == 1 ) {
00647             // calculate the weighted error using the pruning validation sample
00648             return (n->GetPurity() * n->GetNBValidation() + (1.0 - n->GetPurity()) * n->GetNSValidation());
00649          }
00650          else {
00651             throw std::string("Unknown ValidationQualityMode");
00652          }
00653       }
00654    }
00655 }
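// Editorial worked example (not part of the original file): consider a leaf
// with training purity 0.8 (a "signal" leaf for the default purity limit of
// 0.5) that collected NSValidation = 8 and NBValidation = 2 weighted
// validation events:
//    mode 0 returns the misclassified background weight, i.e. 2;
//    mode 1 returns the purity-weighted error
//           0.8 * 2 + (1 - 0.8) * 8 = 3.2.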
00656 
00657 //_______________________________________________________________________
00658 void TMVA::DecisionTree::CheckEventWithPrunedTree( const Event& e ) const
00659 {
00660    // pass a single validation event through a pruned decision tree;
00661    // on the way down the tree, fill in all the "intermediate" information
00662    // that would normally be there from training.
00663 
00664    DecisionTreeNode* current =  this->GetRoot();
00665    if (current == NULL) {
00666       Log() << kFATAL << "CheckEventWithPrunedTree: started with undefined ROOT node" <<Endl;
00667    }
00668 
00669    while(current != NULL) {
00670       if(e.GetClass() == fClass)
00671          current->SetNSValidation(current->GetNSValidation() + e.GetWeight());
00672       else
00673          current->SetNBValidation(current->GetNBValidation() + e.GetWeight());
00674 
00675       if (e.GetNTargets() > 0) {
00676          current->AddToSumTarget(e.GetWeight()*e.GetTarget(0));
00677          current->AddToSumTarget2(e.GetWeight()*e.GetTarget(0)*e.GetTarget(0));
00678       }
00679 
00680       if (current->GetRight() == NULL || current->GetLeft() == NULL) {
00681          current = NULL;
00682       }
00683       else {
00684          if (current->GoesRight(e))
00685             current = (TMVA::DecisionTreeNode*)current->GetRight();
00686          else
00687             current = (TMVA::DecisionTreeNode*)current->GetLeft();
00688       }
00689    }
00690 }
00691 
00692 //_______________________________________________________________________
00693 Double_t TMVA::DecisionTree::GetSumWeights( const EventList* validationSample ) const
00694 {
00695    // calculate the normalization factor for a pruning validation sample
00696    Double_t sumWeights = 0.0;
00697    for( EventList::const_iterator it = validationSample->begin();
00698         it != validationSample->end(); ++it ) {
00699       sumWeights += (*it)->GetWeight();
00700    }
00701    return sumWeights;
00702 }
00703 
00704 
00705 
00706 //_______________________________________________________________________
00707 UInt_t TMVA::DecisionTree::CountLeafNodes( TMVA::Node *n )
00708 {
00709    // return the number of terminal nodes in the sub-tree below Node n
00710   
00711    if (n == NULL) { // default, start at the tree top, then descend recursively
00712       n =  this->GetRoot();
00713       if (n == NULL) {
00714          Log() << kFATAL << "CountLeafNodes: started with undefined ROOT node" <<Endl;
00715          return 0;
00716       }
00717    } 
00718   
00719    UInt_t countLeafs=0;
00720   
00721    if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) == NULL) ) {
00722       countLeafs += 1;
00723    } 
00724    else { 
00725       if (this->GetLeftDaughter(n) != NULL) {
00726          countLeafs += this->CountLeafNodes( this->GetLeftDaughter(n) );
00727       }
00728       if (this->GetRightDaughter(n) != NULL) {
00729          countLeafs += this->CountLeafNodes( this->GetRightDaughter(n) );
00730       }
00731    }
00732    return countLeafs;
00733 }
00734 
00735 //_______________________________________________________________________
00736 void TMVA::DecisionTree::DescendTree( Node* n )
00737 {
00738    // descend a tree to find all its leaf nodes
00739   
00740    if (n == NULL) { // default, start at the tree top, then descend recursively
00741       n =  this->GetRoot();
00742       if (n == NULL) {
00743          Log() << kFATAL << "DescendTree: started with undefined ROOT node" <<Endl;
00744          return ;
00745       }
00746    } 
00747   
00748    if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) == NULL) ) {
00749       // do nothing
00750    } 
00751    else if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) != NULL) ) {
00752       Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
00753       return;
00754    }  
00755    else if ((this->GetLeftDaughter(n) != NULL) && (this->GetRightDaughter(n) == NULL) ) {
00756       Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
00757       return;
00758    } 
00759    else { 
00760       if (this->GetLeftDaughter(n) != NULL) {
00761          this->DescendTree( this->GetLeftDaughter(n) );
00762       }
00763       if (this->GetRightDaughter(n) != NULL) {
00764          this->DescendTree( this->GetRightDaughter(n) );
00765       }
00766    }
00767 }
00768 
00769 //_______________________________________________________________________
00770 void TMVA::DecisionTree::PruneNode( DecisionTreeNode* node )
00771 {
00772    // prune away the subtree below the node 
00773    DecisionTreeNode *l = node->GetLeft();
00774    DecisionTreeNode *r = node->GetRight();
00775 
00776    node->SetRight(NULL);
00777    node->SetLeft(NULL);
00778    node->SetSelector(-1);
00779    node->SetSeparationGain(-1);
00780    if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
00781    else node->SetNodeType(-1);
00782    this->DeleteNode(l);
00783    this->DeleteNode(r);
00784    // update the stored number of nodes in the Tree
00785    this->CountNodes();
00786   
00787 }
00788 
00789 //_______________________________________________________________________
00790 void TMVA::DecisionTree::PruneNodeInPlace( DecisionTreeNode* node ) {
00791    // prune a node temporarily (without actually deleting its descendants),
00792    // which allows testing the pruned tree quality for many different
00793    // pruning stages without "touching" the tree.
00794 
00795    if(node == NULL) return;
00796    node->SetNTerminal(1);
00797    node->SetSubTreeR( node->GetNodeR() );
00798    node->SetAlpha( std::numeric_limits<double>::infinity( ) );
00799    node->SetAlphaMinSubtree( std::numeric_limits<double>::infinity( ) );
00800    node->SetTerminal(kTRUE); // set the node to be terminal without deleting its descendants FIXME not needed
00801 }
00802 
00803 //_______________________________________________________________________
00804 TMVA::Node* TMVA::DecisionTree::GetNode( ULong_t sequence, UInt_t depth )
00805 {
00806    // retrieve a node from the tree. Its position (up to a maximal tree depth of 64)
00807    // is coded as a sequence of left-right moves starting from the root,
00808    // stored as 0-1 bit patterns in the "long integer" (i.e. 0:left ; 1:right)
00809   
00810    Node* current = this->GetRoot();
00811   
00812    for (UInt_t i =0;  i < depth; i++) {
00813       ULong_t tmp = 1 << i;
00814       if ( tmp & sequence) current = this->GetRightDaughter(current);
00815       else current = this->GetLeftDaughter(current);
00816    }
00817   
00818    return current;
00819 }
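// Editorial worked example (not part of the original file): the sequence is
// decoded least-significant bit first, so GetNode(5, 3) with 5 = 0b101
// walks right (bit 0 = 1), left (bit 1 = 0), right (bit 2 = 1) and returns
// the node reached via right-left-right from the root.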
00820 
00821 
00822 //_______________________________________________________________________
00823 void TMVA::DecisionTree::GetRandomisedVariables(Bool_t *useVariable, UInt_t *mapVariable, UInt_t &useNvars){
00824    // choose a random subset of the variables to be considered for the node splitting (randomised trees)
00825    for (UInt_t ivar=0; ivar<fNvars; ivar++) useVariable[ivar]=kFALSE;
00826    if (fUseNvars==0) { // no number specified ... choose something that hopefully works well
00827       // watch out, should never happen as it is initialised automatically in MethodBDT already!!!
00828       fUseNvars        =  UInt_t(TMath::Sqrt(fNvars)+0.6);
00829    }
00830    if (fUsePoissonNvars) useNvars=TMath::Min(fNvars,TMath::Max(UInt_t(1),(UInt_t) fMyTrandom->Poisson(fUseNvars)));
00831    else useNvars = fUseNvars;
00832 
00833    UInt_t nSelectedVars = 0;
00834    while (nSelectedVars < useNvars) {
00835       Double_t bla = fMyTrandom->Rndm()*fNvars;
00836       useVariable[Int_t (bla)] = kTRUE;
00837       nSelectedVars = 0;
00838       for (UInt_t ivar=0; ivar < fNvars; ivar++) {
00839          if (useVariable[ivar] == kTRUE) { 
00840             mapVariable[nSelectedVars] = ivar;
00841             nSelectedVars++;
00842          }
00843       }
00844    }
00845    if (nSelectedVars != useNvars) { std::cout << "Bug in TrainNode - GetRandomisedVariables()... sorry" << std::endl; std::exit(1);}
00846 }
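// Editorial note (not part of the original file): e.g. with fNvars = 16 and
// no explicit setting, the default above yields UInt_t(TMath::Sqrt(16)+0.6)
// = 4 variables per split; with UsePoissonNvars enabled, the number used at
// each node fluctuates as a Poisson value around 4, clamped to [1, fNvars].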
00847 
00848 //_______________________________________________________________________
00849 Double_t TMVA::DecisionTree::TrainNodeFast( const vector<TMVA::Event*> & eventSample,
00850                                            TMVA::DecisionTreeNode *node )
00851 {
00852    // Decide how to split a node using one of the variables that gives
00853    // the best separation of signal/background. In order to do this, for each 
00854    // variable a scan of the different cut values in a grid (grid = fNCuts) is 
00855    // performed and the resulting separation gains are compared.
00856    // In addition to the individual variables, one can also ask for a Fisher
00857    // discriminant to be built out of (some of) the variables and used as a
00858    // possible multivariate split.
00859 
00860    Double_t separationGain = -1, sepTmp;
00861    Double_t cutValue=-999;
00862    Int_t mxVar= -1;
00863    Int_t cutIndex=-1;
00864    Bool_t cutType=kTRUE;
00865    Double_t  nTotS, nTotB;
00866    Int_t     nTotS_unWeighted, nTotB_unWeighted; 
00867    UInt_t nevents = eventSample.size();
00868 
00869 
00870    // the +1 comes from the fact that I treat later on the Fisher output as an 
00871    // additional possible variable.
00872    Bool_t *useVariable = new Bool_t[fNvars+1];   // for performance reasons instead of vector<Bool_t> useVariable(fNvars);
00873    UInt_t *mapVariable = new UInt_t[fNvars+1];    // map the subset of variables used in randomised trees to the original variable number (used in the Event() ) 
00874 
00875    std::vector<Double_t> fisherCoeff;
00876  
00877    if (fRandomisedTree) { // choose for each node splitting a random subset of variables to choose from
00878       UInt_t tmp=fUseNvars;
00879       GetRandomisedVariables(useVariable,mapVariable,tmp);
00880    } 
00881    else {
00882       for (UInt_t ivar=0; ivar < fNvars; ivar++) {
00883          useVariable[ivar] = kTRUE;
00884          mapVariable[ivar] = ivar;
00885       }
00886    }
00887    useVariable[fNvars] = kFALSE; // by default Fisher is not used
00888 
00889    if (fUseFisherCuts) {
00890       useVariable[fNvars] = kTRUE; // that's where I store the "Fisher MVA"
00891 
00892       //use for the Fisher discriminant ONLY those variables that show
00893       //some reasonable linear correlation in either Signal or Background
00894       Bool_t *useVarInFisher = new Bool_t[fNvars];   // for performance reasons instead of vector<Bool_t>
00895       UInt_t *mapVarInFisher = new UInt_t[fNvars];   // map the variables used in the Fisher discriminant to the original variable number
00896       for (UInt_t ivar=0; ivar < fNvars; ivar++) {
00897          useVarInFisher[ivar] = kFALSE;
00898          mapVarInFisher[ivar] = ivar;
00899       }
00900       
00901       std::vector<TMatrixDSym*>* covMatrices;
00902       covMatrices = gTools().CalcCovarianceMatrices( eventSample, 2 ); // currently for 2 classes only
00903       TMatrixD *ss = new TMatrixD(*(covMatrices->at(0)));
00904       TMatrixD *bb = new TMatrixD(*(covMatrices->at(1)));
00905       const TMatrixD *s = gTools().GetCorrelationMatrix(ss);
00906       const TMatrixD *b = gTools().GetCorrelationMatrix(bb);
00907       
00908       for (UInt_t ivar=0; ivar < fNvars; ivar++) {
00909          for (UInt_t jvar=ivar+1; jvar < fNvars; jvar++) {
00910             if (  ( TMath::Abs( (*s)(ivar, jvar)) > fMinLinCorrForFisher) ||
00911                   ( TMath::Abs( (*b)(ivar, jvar)) > fMinLinCorrForFisher) ){
00912                useVarInFisher[ivar] = kTRUE;
00913                useVarInFisher[jvar] = kTRUE;
00914             }
00915          }
00916       }
00917       // now as you know which variables you want to use, count and map them:
00918       // such that you can use an array/matrix filled only with THOSE variables
00919       // that you used
00920       UInt_t nFisherVars = 0;
00921       for (UInt_t ivar=0; ivar < fNvars; ivar++) {
00922          // now .. pick those variables that are used in the Fisher and are also
00923          // part of the "allowed" variables (in case of randomised trees)
00924          if (useVarInFisher[ivar] && useVariable[ivar]) {
00925             mapVarInFisher[nFisherVars++]=ivar;
00926             // now exclude the variables used in the Fisher cuts, and don't
00927             // use them anymore in the individual variable scan
00928             if (fUseExclusiveVars) useVariable[ivar] = kFALSE;
00929          }
00930       }
00931       
00932       
00933       fisherCoeff = this->GetFisherCoefficients(eventSample, nFisherVars, mapVarInFisher);
00934       delete [] useVarInFisher;
00935       delete [] mapVarInFisher;
00936    }
00937 
00938 
00939    const UInt_t nBins = fNCuts+1;
00940    UInt_t cNvars = fNvars;
00941    if (fUseFisherCuts) cNvars++;  // use the Fisher output simply as an additional variable
00942 
00943    Double_t** nSelS = new Double_t* [cNvars];
00944    Double_t** nSelB = new Double_t* [cNvars];
00945    Double_t** nSelS_unWeighted = new Double_t* [cNvars];
00946    Double_t** nSelB_unWeighted = new Double_t* [cNvars];
00947    Double_t** target = new Double_t* [cNvars];
00948    Double_t** target2 = new Double_t* [cNvars];
00949    Double_t** cutValues = new Double_t* [cNvars];
00950 
00951    for (UInt_t i=0; i<cNvars; i++) {
00952       nSelS[i] = new Double_t [nBins];
00953       nSelB[i] = new Double_t [nBins];
00954       nSelS_unWeighted[i] = new Double_t [nBins];
00955       nSelB_unWeighted[i] = new Double_t [nBins];
00956       target[i] = new Double_t [nBins];
00957       target2[i] = new Double_t [nBins];
00958       cutValues[i] = new Double_t [nBins];
00959    }
00960 
00961    Double_t *xmin = new Double_t[cNvars]; 
00962    Double_t *xmax = new Double_t[cNvars];
00963 
00964    for (UInt_t ivar=0; ivar < cNvars; ivar++) {
00965       if (ivar < fNvars){
00966          xmin[ivar]=node->GetSampleMin(ivar);
00967          xmax[ivar]=node->GetSampleMax(ivar);
00968       } else { // the fisher variable
00969          xmin[ivar]=999;
00970          xmax[ivar]=-999;
00971          // too bad, for the moment I don't know how to do this without looping
00972          // once to get the "min max" and then AGAIN to fill the histogram
00973          for (UInt_t iev=0; iev<nevents; iev++) {
00974             // returns the Fisher value (no fixed range)
00975             Double_t result = fisherCoeff[fNvars]; // the fisher constant offset
00976             for (UInt_t jvar=0; jvar<fNvars; jvar++)
00977                result += fisherCoeff[jvar]*(eventSample[iev])->GetValue(jvar);
00978             if (result > xmax[ivar]) xmax[ivar]=result;
00979             if (result < xmin[ivar]) xmin[ivar]=result;
00980          }
00981       }
00982       for (UInt_t ibin=0; ibin<nBins; ibin++) {
00983          nSelS[ivar][ibin]=0;
00984          nSelB[ivar][ibin]=0;
00985          nSelS_unWeighted[ivar][ibin]=0;
00986          nSelB_unWeighted[ivar][ibin]=0;
00987          target[ivar][ibin]=0;
00988          target2[ivar][ibin]=0;
00989          cutValues[ivar][ibin]=0;
00990       }
00991    }
00992 
00993    // fill the cut values for the scan:
00994    for (UInt_t ivar=0; ivar < cNvars; ivar++) {
00995 
00996       if ( useVariable[ivar] ) {
00997          
00998          //set the grid for the cut scan on the variables like this:
00999          // 
01000          //  |       |        |         |         |   ...      |        |  
01001          // xmin                                                       xmax
01002          //
01003          // cut      0        1         2         3   ...     fNCuts-1 (counting from zero)
01004          // bin  0       1         2         3       .....      nBins-1=fNCuts (counting from zero)
01005          // --> nBins = fNCuts+1
01006          // (NOTE: cuts at xmin or xmax would just give the whole sample and
01007          //  hence can be safely omitted)
01008          
01009          Double_t istepSize =( xmax[ivar] - xmin[ivar] ) / Double_t(nBins);
01010          for (Int_t icut=0; icut<fNCuts; icut++) {
01011             cutValues[ivar][icut]=xmin[ivar]+(Double_t(icut+1))*istepSize;
01012          }
01013       }
01014    }
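   // Editorial worked example (not part of the original file): with
   // fNCuts = 4 (so nBins = 5) and a variable ranging over [0,10], the step
   // size is 10/5 = 2 and the scanned cut values are 2, 4, 6 and 8; the
   // boundaries 0 and 10 are omitted as they would not split the sample.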
01015   
01016    nTotS=0; nTotB=0;
01017    nTotS_unWeighted=0; nTotB_unWeighted=0;   
01018    for (UInt_t iev=0; iev<nevents; iev++) {
01019 
01020       Double_t eventWeight =  eventSample[iev]->GetWeight(); 
01021       if (eventSample[iev]->GetClass() == fClass) {
01022          nTotS+=eventWeight;
01023          nTotS_unWeighted++;
01024       }
01025       else {
01026          nTotB+=eventWeight;
01027          nTotB_unWeighted++;
01028       }
01029       
01030       Int_t iBin=-1;
01031       for (UInt_t ivar=0; ivar < cNvars; ivar++) {
01032          // fill, for each variable, the "histogram" of the cut-value grid from
01033          // which the best cut will be determined in the scan further below.
01034          if ( useVariable[ivar] ) {
01035             Double_t eventData;
01036             if (ivar < fNvars) eventData = eventSample[iev]->GetValue(ivar); 
01037             else { // the fisher variable
01038                eventData = fisherCoeff[fNvars];
01039                for (UInt_t jvar=0; jvar<fNvars; jvar++)
01040                   eventData += fisherCoeff[jvar]*(eventSample[iev])->GetValue(jvar);
01041                
01042             }
01043             // "maximum" is nbins-1 (the "-1" because we start counting from 0 !!
01044             iBin = TMath::Min(Int_t(nBins-1),TMath::Max(0,int (nBins*(eventData-xmin[ivar])/(xmax[ivar]-xmin[ivar]) ) ));
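            // Editorial worked example (not part of the original file): for
            // xmin = 0, xmax = 10, nBins = 5 and eventData = 7.3 this gives
            // iBin = Min(4, Max(0, int(5*7.3/10))) = Min(4, 3) = 3.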
01045             if (eventSample[iev]->GetClass() == fClass) {
01046                nSelS[ivar][iBin]+=eventWeight;
01047                nSelS_unWeighted[ivar][iBin]++;
01048             } 
01049             else {
01050                nSelB[ivar][iBin]+=eventWeight;
01051                nSelB_unWeighted[ivar][iBin]++;
01052             }
01053             if (DoRegression()) {
01054                target[ivar][iBin] +=eventWeight*eventSample[iev]->GetTarget(0);
01055                target2[ivar][iBin]+=eventWeight*eventSample[iev]->GetTarget(0)*eventSample[iev]->GetTarget(0);
01056             }
01057          }
01058       }
01059    }   
01060    // now turn the "histogram" into a cumulative distribution
01061    for (UInt_t ivar=0; ivar < cNvars; ivar++) {
01062       if (useVariable[ivar]) {
01063          for (UInt_t ibin=1; ibin < nBins; ibin++) {
01064             nSelS[ivar][ibin]+=nSelS[ivar][ibin-1];
01065             nSelS_unWeighted[ivar][ibin]+=nSelS_unWeighted[ivar][ibin-1];
01066             nSelB[ivar][ibin]+=nSelB[ivar][ibin-1];
01067             nSelB_unWeighted[ivar][ibin]+=nSelB_unWeighted[ivar][ibin-1];
01068             if (DoRegression()) {
01069                target[ivar][ibin] +=target[ivar][ibin-1] ;
01070                target2[ivar][ibin]+=target2[ivar][ibin-1];
01071             }
01072          }
01073          if (nSelS_unWeighted[ivar][nBins-1] +nSelB_unWeighted[ivar][nBins-1] != eventSample.size()) {
01074             Log() << kFATAL << "Helge, you have a bug ....nSelS_unw..+nSelB_unw..= "
01075                   << nSelS_unWeighted[ivar][nBins-1] +nSelB_unWeighted[ivar][nBins-1] 
01076                   << " while eventsample size = " << eventSample.size()
01077                   << Endl;
01078          }
01079          double lastBins=nSelS[ivar][nBins-1] +nSelB[ivar][nBins-1];
01080          double totalSum=nTotS+nTotB;
01081          if (TMath::Abs(lastBins-totalSum)/totalSum>0.01) {
01082             Log() << kFATAL << "Helge, you have another bug ....nSelS+nSelB= "
01083                   << lastBins
01084                   << " while total number of events = " << totalSum
01085                   << Endl;
01086          }
01087       }
01088    }
01089    // now select the optimal cuts for each variable and find which one gives
01090    // the best separationGain at the current stage
01091    for (UInt_t ivar=0; ivar < cNvars; ivar++) {
01092       if (useVariable[ivar]) {
01093          for (UInt_t iBin=0; iBin<nBins-1; iBin++) { // the last bin contains "all events" -->skip
01094             // the separationGain is defined via the various indices (Gini, CrossEntropy, etc.)
01095             // calculated from the "sample purities" of the branches that would go to the
01096             // left or the right from this node if "these" cuts were used in the node:
01097             // hereby: nSelS and nSelB would go to the right branch
01098             //        (nTotS - nSelS) + (nTotB - nSelB)  would go to the left branch;
01099 
01100             // only allow splits where both daughter nodes match the specified minimum number
01101             // for this use the "unweighted" events, as you are interested in statistically 
01102             // significant splits, which is determined by the actual number of entries
01103             // for a node, rather than the sum of event weights.
01104 
01105             Double_t sl = nSelS_unWeighted[ivar][iBin];
01106             Double_t bl = nSelB_unWeighted[ivar][iBin];
01107             Double_t s  = nTotS_unWeighted;
01108             Double_t b  = nTotB_unWeighted;
01109             Double_t sr = s-sl;
01110             Double_t br = b-bl;
01111             if ( (sl+bl)>=fMinSize && (sr+br)>=fMinSize ) {
01112 
01113                if (DoRegression()) {
01114                   sepTmp = fRegType->GetSeparationGain(nSelS[ivar][iBin]+nSelB[ivar][iBin], 
01115                                                        target[ivar][iBin],target2[ivar][iBin],
01116                                                        nTotS+nTotB,
01117                                                        target[ivar][nBins-1],target2[ivar][nBins-1]);
01118                } else {
01119                   sepTmp = fSepType->GetSeparationGain(nSelS[ivar][iBin], nSelB[ivar][iBin], nTotS, nTotB);
01120                }
01121                if (separationGain < sepTmp) {
01122                   separationGain = sepTmp;
01123                   mxVar = ivar;
01124                   cutIndex = iBin;
01125                   if (cutIndex >= fNCuts) Log()<<kFATAL<<"ibin for cut " << iBin << Endl; 
01126                }
01127             }
01128          }
01129       }
01130    }
01131    
01132    if (DoRegression()) {
01133       node->SetSeparationIndex(fRegType->GetSeparationIndex(nTotS+nTotB,target[0][nBins-1],target2[0][nBins-1]));
01134       node->SetResponse(target[0][nBins-1]/(nTotS+nTotB));
01135       node->SetRMS(TMath::Sqrt(target2[0][nBins-1]/(nTotS+nTotB) - target[0][nBins-1]/(nTotS+nTotB)*target[0][nBins-1]/(nTotS+nTotB)));
01136    }
01137    else {
01138       node->SetSeparationIndex(fSepType->GetSeparationIndex(nTotS,nTotB));
01139    }
01140    if (mxVar >= 0) { 
01141       if (nSelS[mxVar][cutIndex]/nTotS > nSelB[mxVar][cutIndex]/nTotB) cutType=kTRUE;
01142       else cutType=kFALSE;      
01143       cutValue = cutValues[mxVar][cutIndex];
01144     
01145       node->SetSelector((UInt_t)mxVar);
01146       node->SetCutValue(cutValue);
01147       node->SetCutType(cutType);
01148       node->SetSeparationGain(separationGain);
01149       if (mxVar < (Int_t) fNvars){ // the Fisher cut is actually not used in this node, hence no need to store Fisher components
01150          node->SetNFisherCoeff(0);
01151          fVariableImportance[mxVar] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB) ;
01152       }else{
01153          // allocate Fisher coefficients (use fNvars, and set the non-used ones to zero. Might
01154          // be even less storage space on average than storing also the mapping used otherwise
01155          // can always be changed relatively easy
01156          node->SetNFisherCoeff(fNvars+1);     
01157          for (UInt_t ivar=0; ivar<=fNvars; ivar++) {
01158             node->SetFisherCoeff(ivar,fisherCoeff[ivar]);
01159             // take the Fisher-coefficient-weighted estimate as variable importance; don't fill the offset coefficient though :)
01160             if (ivar<fNvars){
01161                fVariableImportance[ivar] += fisherCoeff[ivar]*separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB) ;
01162             }
01163          }
01164       } 
01165    }
01166    else {
01167       separationGain = 0;
01168    }
01169   
01170 
01171    for (UInt_t i=0; i<cNvars; i++) {
01172       delete [] nSelS[i];
01173       delete [] nSelB[i];
01174       delete [] nSelS_unWeighted[i];
01175       delete [] nSelB_unWeighted[i];
01176       delete [] target[i];
01177       delete [] target2[i];
01178       delete [] cutValues[i];
01179    }
01180    delete [] nSelS;
01181    delete [] nSelB;
01182    delete [] nSelS_unWeighted;
01183    delete [] nSelB_unWeighted;
01184    delete [] target;
01185    delete [] target2;
01186    delete [] cutValues;
01187 
01188    delete [] xmin;
01189    delete [] xmax;
01190 
01191    delete [] useVariable;
01192    delete [] mapVariable;
01193 
01194    return separationGain;
01195 
01196 }
01197 
01198 
01199 
01200 //_______________________________________________________________________
01201 std::vector<Double_t>  TMVA::DecisionTree::GetFisherCoefficients(const EventList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher){ 
01202    // calculate the Fisher coefficients for the event sample and the variables used
01203 
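   //
   // Editorial note (not part of the original file): in the classical Fisher
   // construction the coefficients are, up to normalisation,
   //    w = W^{-1} ( mu_S - mu_B )
   // where W is the 'within class' covariance matrix and mu_S, mu_B are the
   // signal and background mean vectors computed below.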
01204    std::vector<Double_t> fisherCoeff(fNvars+1);
01205 
01206    // initialization of global matrices and vectors
01207    // average value of each variables for S, B, S+B
01208    TMatrixD* meanMatx = new TMatrixD( nFisherVars, 3 );
01209    
01210    // the covariance 'within class' and 'between class' matrices
01211    TMatrixD* betw = new TMatrixD( nFisherVars, nFisherVars );
01212    TMatrixD* with = new TMatrixD( nFisherVars, nFisherVars );
01213    TMatrixD* cov  = new TMatrixD( nFisherVars, nFisherVars );
01214 
01215    //
01216    // compute mean values of variables in each sample, and the overall means
01217    //
01218 
01219    // initialize internal sum-of-weights variables
01220    Double_t sumOfWeightsS = 0;
01221    Double_t sumOfWeightsB = 0;
01222    
01223    
01224    // init vectors
01225    Double_t* sumS = new Double_t[nFisherVars];
01226    Double_t* sumB = new Double_t[nFisherVars];
01227    for (UInt_t ivar=0; ivar<nFisherVars; ivar++) { sumS[ivar] = sumB[ivar] = 0; }   
01228 
01229    UInt_t nevents = eventSample.size();   
01230    // compute sample means
01231    for (UInt_t ievt=0; ievt<nevents; ievt++) {
01232       
01233       // read the training event into "ev"
01234       const Event * ev = eventSample[ievt];
01235 
01236       // sum of weights
01237       Double_t weight = ev->GetWeight();
01238       if (ev->GetClass() == fClass) sumOfWeightsS += weight;
01239       else                          sumOfWeightsB += weight;
01240 
01241       Double_t* sum = ev->GetClass() == fClass ? sumS : sumB;
01242       for (UInt_t ivar=0; ivar<nFisherVars; ivar++) sum[ivar] += ev->GetValue( mapVarInFisher[ivar] )*weight;
01243    }
01244 
01245    for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {   
01246       (*meanMatx)( ivar, 2 ) = sumS[ivar];
01247       (*meanMatx)( ivar, 0 ) = sumS[ivar]/sumOfWeightsS;
01248       
01249       (*meanMatx)( ivar, 2 ) += sumB[ivar];
01250       (*meanMatx)( ivar, 1 ) = sumB[ivar]/sumOfWeightsB;
01251       
01252       // signal + background
01253       (*meanMatx)( ivar, 2 ) /= (sumOfWeightsS + sumOfWeightsB);
01254    }  
01255    delete [] sumS;
01256    delete [] sumB;
01257 
01258    // the matrix of covariance 'within class' reflects the dispersion of the
01259    // events relative to the center of gravity of their own class  
01260 
01261    // both class sums of weights must be positive, otherwise the means are undefined
01262 
01263    assert( sumOfWeightsS > 0 && sumOfWeightsB > 0 );
01264 
01265    // product matrices (x-<x>)(y-<y>) where x;y are variables
01266 
01267    const Int_t nFisherVars2 = nFisherVars*nFisherVars;
01268    Double_t *sum2Sig  = new Double_t[nFisherVars2];
01269    Double_t *sum2Bgd  = new Double_t[nFisherVars2];
01270    Double_t *xval    = new Double_t[nFisherVars2];
01271    memset(sum2Sig,0,nFisherVars2*sizeof(Double_t));
01272    memset(sum2Bgd,0,nFisherVars2*sizeof(Double_t));
01273    
01274    // 'within class' covariance
01275    for (UInt_t ievt=0; ievt<nevents; ievt++) {
01276 
01277       // read the training event into "ev"
01278       const Event* ev = eventSample[ievt];
01279 
01280       Double_t weight = ev->GetWeight(); // may ignore events with negative weights
01281 
01282       for (UInt_t x=0; x<nFisherVars; x++) xval[x] = ev->GetValue( mapVarInFisher[x] );
01283       Int_t k=0;
01284       for (UInt_t x=0; x<nFisherVars; x++) {
01285          for (UInt_t y=0; y<nFisherVars; y++) {            
01286             Double_t v = ( (xval[x] - (*meanMatx)(x, 0))*(xval[y] - (*meanMatx)(y, 0)) )*weight;
01287             if ( ev->GetClass() == fClass ) sum2Sig[k] += v;
01288             else                            sum2Bgd[k] += v;
01289             k++;
01290          }
01291       }
01292    }
01293    Int_t k=0;
01294    for (UInt_t x=0; x<nFisherVars; x++) {
01295       for (UInt_t y=0; y<nFisherVars; y++) {
01296          (*with)(x, y) = (sum2Sig[k] + sum2Bgd[k])/(sumOfWeightsS + sumOfWeightsB);
01297          k++;
01298       }
01299    }
01300 
01301    delete [] sum2Sig;
01302    delete [] sum2Bgd;
01303    delete [] xval;
01304 
01305 
01306    // the matrix of covariance 'between class' reflects the dispersion of the
01307    // events of a class relative to the global center of gravity of all classes,
01308    // and hence the separation between the classes
01309 
01310 
01311    Double_t prodSig, prodBgd;
01312 
01313    for (UInt_t x=0; x<nFisherVars; x++) {
01314       for (UInt_t y=0; y<nFisherVars; y++) {
01315 
01316          prodSig = ( ((*meanMatx)(x, 0) - (*meanMatx)(x, 2))*
01317                      ((*meanMatx)(y, 0) - (*meanMatx)(y, 2)) );
01318          prodBgd = ( ((*meanMatx)(x, 1) - (*meanMatx)(x, 2))*
01319                      ((*meanMatx)(y, 1) - (*meanMatx)(y, 2)) );
01320 
01321          (*betw)(x, y) = (sumOfWeightsS*prodSig + sumOfWeightsB*prodBgd) / (sumOfWeightsS + sumOfWeightsB);
01322       }
01323    }
01324 
01325 
01326 
01327    // compute full covariance matrix from sum of within and between matrices
01328    for (UInt_t x=0; x<nFisherVars; x++) 
01329       for (UInt_t y=0; y<nFisherVars; y++) 
01330          (*cov)(x, y) = (*with)(x, y) + (*betw)(x, y);
01331         
01332    // Fisher = Sum { [coeff]*[variables] }
01333    //
01334    // let Xs be the array of the mean values of variables for signal evts
01335    // let Xb be the array of the mean values of variables for backgd evts
01336    // let InvWith be the inverse of the 'within class' covariance matrix
01337    //
01338    // then the array of Fisher coefficients is 
01339    // [coeff] = TMath::Sqrt(fNsig*fNbgd)/fNevt * transpose{Xs-Xb} * InvWith
01340    TMatrixD* theMat = with; // Fishers original
01341    //   TMatrixD* theMat = cov; // Mahalanobis
01342       
01343    TMatrixD invCov( *theMat );
01344    if ( TMath::Abs(invCov.Determinant()) < 10E-24 ) {
01345       Log() << kWARNING << "FisherCoeff matrix is almost singular with determinant="
01346               << TMath::Abs(invCov.Determinant()) 
01347               << " - did you use variables that are linear combinations or highly correlated?" 
01348               << Endl;
01349    }
01350    if ( TMath::Abs(invCov.Determinant()) < 10E-120 ) {
01351       Log() << kFATAL << "FisherCoeff matrix is singular with determinant="
01352               << TMath::Abs(invCov.Determinant())  
01353               << " did you use the variables that are linear combinations?" 
01354               << Endl;
01355    }
01356 
01357    invCov.Invert();
01358    
01359    // apply rescaling factor
01360    Double_t xfact = TMath::Sqrt( sumOfWeightsS*sumOfWeightsB ) / (sumOfWeightsS + sumOfWeightsB);
01361 
01362    // compute difference of mean values
01363    std::vector<Double_t> diffMeans( nFisherVars );
01364 
01365    for (UInt_t ivar=0; ivar<=fNvars; ivar++) fisherCoeff[ivar] = 0;
01366    for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
01367       for (UInt_t jvar=0; jvar<nFisherVars; jvar++) {
01368          Double_t d = (*meanMatx)(jvar, 0) - (*meanMatx)(jvar, 1);
01369          fisherCoeff[mapVarInFisher[ivar]] += invCov(ivar, jvar)*d;
01370       }    
01371     
01372       // rescale
01373       fisherCoeff[mapVarInFisher[ivar]] *= xfact;
01374    }
01375 
01376    // offset correction
01377    Double_t f0 = 0.0;
01378    for (UInt_t ivar=0; ivar<nFisherVars; ivar++){ 
01379       f0 += fisherCoeff[mapVarInFisher[ivar]]*((*meanMatx)(ivar, 0) + (*meanMatx)(ivar, 1));
01380    }
01381    f0 /= -2.0;  
01382 
01383    fisherCoeff[fNvars] = f0;  // since variables are counted from zero, the Fisher offset is stored at the end
01384    
01385    return fisherCoeff;
01386 }
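
//_______________________________________________________________________
// Editor's sketch (illustrative, not part of TMVA): the computation above boils
// down to coeff = sqrt(wS*wB)/(wS+wB) * W^{-1} * (meanS - meanB), with W the
// 'within class' covariance matrix, plus an offset stored last. A condensed
// version using ROOT's TMatrixD/TVectorD (the function name is hypothetical,
// and W is assumed to be non-singular):

std::vector<Double_t> SketchFisherCoefficients(const TMatrixD &within,
                                               const TVectorD &meanS,
                                               const TVectorD &meanB,
                                               Double_t wS, Double_t wB)
{
   const Int_t n = meanS.GetNrows();
   TMatrixD invW(within);
   invW.Invert();                                    // W^{-1}

   TVectorD coeff = invW * (meanS - meanB);          // unscaled Fisher direction
   Double_t xfact = TMath::Sqrt(wS*wB) / (wS + wB);  // rescaling factor, as above

   std::vector<Double_t> result(n+1, 0.0);
   Double_t f0 = 0.0;
   for (Int_t i = 0; i < n; i++) {
      result[i] = coeff(i) * xfact;
      f0 += result[i] * (meanS(i) + meanB(i));       // offset from the class means
   }
   result[n] = -0.5 * f0;                            // offset stored at the end, as above
   return result;
}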
01387 
01388 //_______________________________________________________________________
01389 Double_t TMVA::DecisionTree::TrainNodeFull( const vector<TMVA::Event*> & eventSample,
01390                                            TMVA::DecisionTreeNode *node )
01391 {
01392   
01393    // train a node by finding the single optimal cut for a single variable
01394    // that best separates signal and background (maximizes the separation gain)
01395   
01396    Double_t nTotS = 0.0, nTotB = 0.0;
01397    Int_t nTotS_unWeighted = 0, nTotB_unWeighted = 0;  
01398   
01399    vector<TMVA::BDTEventWrapper> bdtEventSample;
01400   
01401    // List of optimal cuts, separation gains, and cut types (removed background or signal) - one for each variable
01402    vector<Double_t> lCutValue( fNvars, 0.0 );
01403    vector<Double_t> lSepGain( fNvars, -1.0e6 );
01404    vector<Char_t> lCutType( fNvars ); // Char_t is used to store a bool (vector<bool> is avoided for performance reasons)
01405    lCutType.assign( fNvars, Char_t(kFALSE) );
01406   
01407    // Initialize (un)weighted counters for signal & background
01408    // Construct a list of event wrappers that point to the original data
01409    for( vector<TMVA::Event*>::const_iterator it = eventSample.begin(); it != eventSample.end(); ++it ) {
01410       if((*it)->GetClass() == fClass) { // signal or background event
01411          nTotS += (*it)->GetWeight();
01412          ++nTotS_unWeighted;
01413       }
01414       else {
01415          nTotB += (*it)->GetWeight();
01416          ++nTotB_unWeighted;
01417       }
01418       bdtEventSample.push_back(TMVA::BDTEventWrapper(*it));
01419    }
01420   
01421    vector<Char_t> useVariable(fNvars); // Char_t is used to store a bool (vector<bool> is avoided for performance reasons)
01422    useVariable.assign( fNvars, Char_t(kFALSE) ); // deselect all variables; the relevant ones are switched on below
01423 
01425    if (fRandomisedTree) { // choose for each node splitting a random subset of variables to choose from
01426       if (fUseNvars == 0) { // no number specified ... choose something that hopefully works well
01427          // this should never happen, as fUseNvars is initialised automatically in MethodBDT
01428          fUseNvars        =  UInt_t(TMath::Sqrt(fNvars)+0.6);
01429       }
01430       UInt_t nSelectedVars = 0; // unsigned, to match fUseNvars in the comparison below
01431       while (nSelectedVars < fUseNvars) {
01432          Double_t bla = fMyTrandom->Rndm()*fNvars;
01433          useVariable[Int_t (bla)] = Char_t(kTRUE);
01434          nSelectedVars = 0;
01435          for (UInt_t ivar=0; ivar < fNvars; ivar++) {
01436             if(useVariable[ivar] == Char_t(kTRUE)) nSelectedVars++;
01437          }
01438       }
01439    } 
01440    else {
01441       for (UInt_t ivar=0; ivar < fNvars; ivar++) useVariable[ivar] = Char_t(kTRUE);
01442    }
01443   
01444    for( UInt_t ivar = 0; ivar < fNvars; ivar++ ) { // loop over all discriminating variables
01445       if(!useVariable[ivar]) continue; // only optimise with the selected variables
01446       TMVA::BDTEventWrapper::SetVarIndex(ivar); // select the variable to sort by
01447       std::sort( bdtEventSample.begin(),bdtEventSample.end() ); // sort the event data 
01448     
01449       Double_t bkgWeightCtr = 0.0, sigWeightCtr = 0.0;
01450       vector<TMVA::BDTEventWrapper>::iterator it = bdtEventSample.begin(), it_end = bdtEventSample.end();
01451       for( ; it != it_end; ++it ) {
01452          if((**it)->GetClass() == fClass ) // signal or background event
01453             sigWeightCtr += (**it)->GetWeight();
01454          else 
01455             bkgWeightCtr += (**it)->GetWeight(); 
01456          // Store the accumulated signal (background) weights
01457          it->SetCumulativeWeight(false,bkgWeightCtr); 
01458          it->SetCumulativeWeight(true,sigWeightCtr);
01459       }
01460     
01461       const Double_t fPMin = 1.0e-6;
01462       Bool_t cutType = kFALSE;
01463       Long64_t index = 0;
01464       Double_t separationGain = -1.0, sepTmp = 0.0, cutValue = 0.0, dVal = 0.0, norm = 0.0;
01465       // Locate the optimal cut for this (ivar-th) variable
01466       for( it = bdtEventSample.begin(); it != it_end; ++it ) {
01467          if( index == 0 ) { ++index; continue; }
01468          if( *(*it) == NULL ) {
01469             Log() << kFATAL << "In TrainNodeFull(): have a null event! Where index=" 
01470                   << index << ", and parent node=" << node->GetParent() << Endl;
01471             break;
01472          }
01473          dVal = bdtEventSample[index].GetVal() - bdtEventSample[index-1].GetVal();
01474          norm = TMath::Abs(bdtEventSample[index].GetVal() + bdtEventSample[index-1].GetVal());
01475          // Only allow splits where both daughter nodes have the specified minimum number of events
01476          // Splits are only sensible when the data are ordered (e.g. don't split inside a sequence of 0's)
01477          if( index >= fMinSize && (nTotS_unWeighted + nTotB_unWeighted) - index >= fMinSize && TMath::Abs(dVal/(0.5*norm + 1)) > fPMin ) {
01478             sepTmp = fSepType->GetSeparationGain( it->GetCumulativeWeight(true), it->GetCumulativeWeight(false), sigWeightCtr, bkgWeightCtr );
01479             if( sepTmp > separationGain ) {
01480                separationGain = sepTmp;
01481                cutValue = it->GetVal() - 0.5*dVal; 
01482                Double_t nSelS = it->GetCumulativeWeight(true);
01483                Double_t nSelB = it->GetCumulativeWeight(false);
01484                // Indicate whether this cut is improving the node purity by removing background (enhancing signal)
01485                // or by removing signal (enhancing background)
01486                if( nSelS/sigWeightCtr > nSelB/bkgWeightCtr ) cutType = kTRUE; 
01487                else cutType = kFALSE; 
01488             }
01489          }
01490          ++index;
01491       }
01492       lCutType[ivar] = Char_t(cutType);
01493       lCutValue[ivar] = cutValue;
01494       lSepGain[ivar] = separationGain;
01495    }
01496   
01497    Double_t separationGain = -1.0;
01498    Int_t iVarIndex = -1;
01499    for( UInt_t ivar = 0; ivar < fNvars; ivar++ ) {
01500       if( lSepGain[ivar] > separationGain ) {
01501          iVarIndex = ivar;
01502          separationGain = lSepGain[ivar];
01503       }
01504    }
01505   
01506    if(iVarIndex >= 0) {
01507       node->SetSelector(iVarIndex);
01508       node->SetCutValue(lCutValue[iVarIndex]);
01509       node->SetSeparationGain(lSepGain[iVarIndex]);
01510       node->SetCutType(lCutType[iVarIndex]);
01511     
01512       fVariableImportance[iVarIndex] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB);
01513    }
01514    else {
01515       separationGain = 0.0;
01516    }
01517   
01518    return separationGain;
01519 }
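
//_______________________________________________________________________
// Editor's sketch (illustrative, not part of TMVA): the core of TrainNodeFull
// is a sort-and-scan over each candidate variable: sort the events by that
// variable, accumulate the signal/background weight from the left, and at each
// allowed boundary evaluate the separation gain of cutting there. The types
// below are hypothetical simplifications; SketchSeparationGain refers to the
// Gini-style sketch given earlier in this file.

struct SketchEvent {
   Double_t val, weight;
   Bool_t   isSignal;
   bool operator<(const SketchEvent &rhs) const { return val < rhs.val; }
};

Double_t SketchBestCut(std::vector<SketchEvent> events, Double_t &bestCutValue)
{
   std::sort(events.begin(), events.end());          // sort by the candidate variable

   Double_t totS = 0, totB = 0;
   for (size_t i = 0; i < events.size(); i++)
      (events[i].isSignal ? totS : totB) += events[i].weight;

   Double_t selS = 0, selB = 0, bestGain = -1.0;
   for (size_t i = 0; i + 1 < events.size(); i++) {
      (events[i].isSignal ? selS : selB) += events[i].weight;  // cumulative weights
      if (events[i+1].val == events[i].val) continue;          // never cut between equal values
      Double_t gain = SketchSeparationGain(selS, selB, totS, totB);
      if (gain > bestGain) {
         bestGain     = gain;
         bestCutValue = 0.5*(events[i].val + events[i+1].val); // cut halfway between neighbours
      }
   }
   return bestGain;
}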
01520 
01521 //___________________________________________________________________________________
01522 TMVA::DecisionTreeNode* TMVA::DecisionTree::GetEventNode(const TMVA::Event & e) const
01523 {
01524    // get the pointer to the leaf node where a particular event ends up
01525    // (used in gradient boosting)
01526 
01527    TMVA::DecisionTreeNode *current = (TMVA::DecisionTreeNode*)this->GetRoot();
01528    while(current->GetNodeType() == 0) { // intermediate node in a tree
01529       current = (current->GoesRight(e)) ?
01530          (TMVA::DecisionTreeNode*)current->GetRight() :
01531          (TMVA::DecisionTreeNode*)current->GetLeft();
01532    }
01533    return current;
01534 }
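
//_______________________________________________________________________
// Editor's sketch (illustrative, not part of TMVA): in gradient boosting the
// leaf returned by GetEventNode is typically used to read off the leaf response
// for an event. A hypothetical helper; 'shrinkage' is an assumed learning-rate
// parameter, not something GetEventNode itself provides:

Double_t SketchTreeContribution(const TMVA::DecisionTree &tree,
                                const TMVA::Event &ev, Double_t shrinkage)
{
   TMVA::DecisionTreeNode *leaf = tree.GetEventNode(ev);
   return shrinkage * leaf->GetResponse();   // this tree's (shrunken) contribution
}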
01535 
01536 //_______________________________________________________________________
01537 Double_t TMVA::DecisionTree::CheckEvent( const TMVA::Event & e, Bool_t UseYesNoLeaf ) const
01538 {
01539    // the event e is put into the decision tree (starting at the root node)
01540    // and the output is the NodeType (signal or background) of the final node (basket)
01541    // in which the given event ends up, i.e. the result of the classification of
01542    // the event by this decision tree
01543   
01544    TMVA::DecisionTreeNode *current = this->GetRoot();
01545    if (!current)
01546       Log() << kFATAL << "CheckEvent: started with an undefined root node" <<Endl;
01547 
01548    while (current->GetNodeType() == 0) { // intermediate node in a (pruned) tree
01549       current = (current->GoesRight(e)) ? 
01550          current->GetRight() :
01551          current->GetLeft();
01552       if (!current) {
01553          Log() << kFATAL << "DT::CheckEvent: inconsistent tree structure" <<Endl;
01554       }
01555 
01556    }
01557   
01558    if ( DoRegression() ){
01559       return current->GetResponse();
01560    } 
01561    else {
01562       if (UseYesNoLeaf) return Double_t ( current->GetNodeType() );
01563       else              return current->GetPurity();
01564    }
01565 }
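
//_______________________________________________________________________
// Editor's sketch (illustrative, not part of TMVA): CheckEvent is what a
// boosted forest evaluates once per tree. A hypothetical unweighted forest
// average using the purity-based leaf output (UseYesNoLeaf = kFALSE):

Double_t SketchForestOutput(const std::vector<TMVA::DecisionTree*> &forest,
                            const TMVA::Event &ev)
{
   if (forest.empty()) return 0.0;
   Double_t sum = 0.0;
   for (size_t i = 0; i < forest.size(); i++)
      sum += forest[i]->CheckEvent(ev, kFALSE);  // purity of the leaf the event lands in
   return sum / forest.size();
}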
01566 
01567 //_______________________________________________________________________
01568 Double_t  TMVA::DecisionTree::SamplePurity( vector<TMVA::Event*> eventSample )
01569 {
01570    // calculates the purity S/(S+B) of a given event sample
01571   
01572    Double_t sumsig=0, sumbkg=0, sumtot=0;
01573    for (UInt_t ievt=0; ievt<eventSample.size(); ievt++) {
01574       if (eventSample[ievt]->GetClass() != fClass) sumbkg+=eventSample[ievt]->GetWeight();
01575       else sumsig+=eventSample[ievt]->GetWeight();
01576       sumtot+=eventSample[ievt]->GetWeight();
01577    }
01578    // sanity check
01579    if (sumtot!= (sumsig+sumbkg)){
01580       Log() << kFATAL << "<SamplePurity> sumtot != sumsig+sumbkg: "
01581             << sumtot << " " << sumsig << " " << sumbkg << Endl;
01582    }
01583    if (sumtot>0) return sumsig/(sumsig + sumbkg);
01584    else return -1;
01585 }
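
//_______________________________________________________________________
// Editor's sketch (illustrative, not part of TMVA): the purity computed above
// is simply the weighted signal fraction p = S/(S+B). Condensed (name
// hypothetical):

Double_t SketchPurity(Double_t sumSig, Double_t sumBkg)
{
   Double_t tot = sumSig + sumBkg;
   return (tot > 0) ? sumSig/tot : -1;   // -1 flags an empty sample, as above
}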
01586 
01587 //_______________________________________________________________________
01588 vector< Double_t >  TMVA::DecisionTree::GetVariableImportance()
01589 {
01590    // Return the relative variable importance, normalized such that the
01591    // importances of all variables sum to 1. The importance is
01592    // evaluated as the total separation-gain that this variable achieved in
01593    // the decision tree (weighted by the number of events)
01594   
01595    vector<Double_t> relativeImportance(fNvars);
01596    Double_t  sum=0;
01597    for (UInt_t i=0; i< fNvars; i++) {
01598       sum += fVariableImportance[i];
01599       relativeImportance[i] = fVariableImportance[i];
01600    } 
01601   
01602    for (UInt_t i=0; i< fNvars; i++) {
01603       if (sum > std::numeric_limits<double>::epsilon())
01604          relativeImportance[i] /= sum;
01605       else 
01606          relativeImportance[i] = 0;
01607    } 
01608    return relativeImportance;
01609 }
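
//_______________________________________________________________________
// Editor's sketch (illustrative, not part of TMVA): during training each
// accepted split adds separationGain^2 * (S+B)^2 to the importance of the
// variable it cuts on (see TrainNodeFast/TrainNodeFull above);
// GetVariableImportance then simply normalises these sums to unity.
// Condensed (name hypothetical):

std::vector<Double_t> SketchNormalisedImportance(const std::vector<Double_t> &raw)
{
   Double_t sum = 0.0;
   for (size_t i = 0; i < raw.size(); i++) sum += raw[i];

   std::vector<Double_t> rel(raw.size(), 0.0);
   if (sum > 0)
      for (size_t i = 0; i < raw.size(); i++) rel[i] = raw[i]/sum;
   return rel;
}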
01610 
01611 //_______________________________________________________________________
01612 Double_t  TMVA::DecisionTree::GetVariableImportance( UInt_t ivar )
01613 {
01614    // returns the relative importance of variable ivar
01615   
01616    vector<Double_t> relativeImportance = this->GetVariableImportance();
01617    if (ivar < fNvars) return relativeImportance[ivar];
01618    else {
01619       Log() << kFATAL << "<GetVariableImportance>" << Endl
01620             << "---                     ivar = " << ivar << " is out of range " << Endl;
01621    }
01622   
01623    return -1;
01624 }
01625 
