ExpectedErrorPruneTool.cxx

Go to the documentation of this file.
00001 /**********************************************************************************
00002  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00003  * Package: TMVA                                                                  *
00004  * Class  : TMVA::ExpectedErrorPruneTool                                          *
00005  * Web    : http://tmva.sourceforge.net                                           *
00006  *                                                                                *
00007  * Description:                                                                   *
00008  *      Implementation of Expected Error Pruning (EEP) for decision trees         *
00009  *                                                                                *
00010  * Authors (alphabetical):                                                        *
00011  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00012  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00013  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
00014  *      Doug Schouten   <dschoute@sfu.ca>        - Simon Fraser U., Canada        *
00015  *                                                                                *
00016  * Copyright (c) 2005:                                                            *
00017  *      CERN, Switzerland                                                         *
00018  *      U. of Victoria, Canada                                                    *
00019  *      MPI-K Heidelberg, Germany                                                 *
00020  *                                                                                *
00021  * Redistribution and use in source and binary forms, with or without             *
00022  * modification, are permitted according to the terms listed in LICENSE           *
00023  * (http://mva.sourceforge.net/license.txt)                                       *
00024  *                                                                                *
00025  **********************************************************************************/
00026 
00027 #include "TMVA/ExpectedErrorPruneTool.h"
00028 #include "TMVA/MsgLogger.h"
00029 
00030 #include <map>
00031 
00032 //_______________________________________________________________________
00033 TMVA::ExpectedErrorPruneTool::ExpectedErrorPruneTool() :
00034    IPruneTool(),
00035    fDeltaPruneStrength(0),
00036    fNodePurityLimit(1),
00037    fLogger( new MsgLogger("ExpectedErrorPruneTool") )
00038 {}
00039 
00040 //_______________________________________________________________________
00041 TMVA::ExpectedErrorPruneTool::~ExpectedErrorPruneTool()
00042 {
00043    delete fLogger;
00044 }
00045 
00046 //_______________________________________________________________________
00047 TMVA::PruningInfo*
00048 TMVA::ExpectedErrorPruneTool::CalculatePruningInfo( DecisionTree* dt,
00049                                                     const IPruneTool::EventSample* validationSample,
00050                                                     Bool_t isAutomatic )
00051 {
00052    if( isAutomatic ) {
00053       //SetAutomatic( );
00054       isAutomatic = kFALSE;
00055       Log() << kWARNING << "Sorry autmoatic pruning strength determination is not implemented yet" << Endl;
00056    }
00057    if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
00058       // must have a valid decision tree to prune, and if the prune strength
00059       // is to be chosen automatically, must have a test sample from
00060       // which to calculate the quality of the pruned tree(s)
00061       return NULL;
00062    }
00063    fNodePurityLimit = dt->GetNodePurityLimit();
00064 
00065    if(IsAutomatic()) { 
00066       Log() << kFATAL << "Sorry autmoatic pruning strength determination is not implemented yet" << Endl;
00067       /*
00068       dt->ApplyValidationSample(validationSample);
00069       Double_t weights = dt->GetSumWeights(validationSample);
00070       // set the initial prune strength
00071       fPruneStrength = 1.0e-3; //! FIXME somehow make this automatic, it depends strongly on the tree structure
00072       // better to set it too small, it will be increased automatically
00073       fDeltaPruneStrength = 1.0e-5;
00074       Int_t nnodes = this->CountNodes((DecisionTreeNode*)dt->GetRoot());
00075 
00076       Bool_t forceStop = kFALSE;
00077       Int_t errCount = 0,
00078          lastNodeCount = nnodes;
00079 
00080       // find the maxiumum prune strength that still leaves the root's daughter nodes
00081       
00082       while ( nnodes > 1 && !forceStop ) {
00083          fPruneStrength += fDeltaPruneStrength;
00084          Log() << "----------------------------------------------------" << Endl;
00085          FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
00086          for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
00087             fPruneSequence[i]->SetTerminal(); // prune all the nodes from the sequence
00088          // test the quality of the pruned tree
00089          Double_t quality = 1.0 - dt->TestPrunedTreeQuality()/weights;
00090          fQualityMap.insert(std::make_pair<const Double_t,Double_t>(quality,fPruneStrength));
00091 
00092          nnodes = CountNodes((DecisionTreeNode*)dt->GetRoot()); // count the number of nodes in the pruned tree
00093 
00094          Log() << "Prune strength : " << fPruneStrength << Endl;
00095          Log() << "Had " << lastNodeCount << " nodes, now have " << nnodes << Endl;
00096          Log() << "Quality index is: " << quality << Endl;
00097 
00098          if (lastNodeCount == nnodes) errCount++;
00099          else {
00100             errCount=0; // reset counter
00101             if ( nnodes < lastNodeCount / 2 ) {
00102                Log() << "Decreasing fDeltaPruneStrength to " << fDeltaPruneStrength/2.0
00103                      << " because the number of nodes in the tree decreased by a factor of 2." << Endl;
00104                fDeltaPruneStrength /= 2.;
00105             }
00106          }
00107          lastNodeCount = nnodes;
00108          if (errCount > 20) {
00109             Log() << "Increasing fDeltaPruneStrength to " << fDeltaPruneStrength*2.0
00110                   << " because the number of nodes in the tree didn't change." << Endl;
00111             fDeltaPruneStrength *= 2.0;
00112          }
00113          if (errCount > 40) {
00114             Log() << "Having difficulty determining the optimal prune strength, bailing out!" << Endl;
00115             forceStop = kTRUE;
00116          }
00117          // reset the tree for the next iteration
00118          for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
00119             fPruneSequence[i]->SetTerminal(false);
00120          fPruneSequence.clear();
00121       }
00122       // from the set of pruned trees, find the one with the optimal quality index
00123       std::multimap<Double_t,Double_t>::reverse_iterator it = fQualityMap.rend(); ++it;
00124       fPruneStrength = it->second;
00125       FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
00126 
00127       // adjust the step size for the next tree automatically
00128       fPruneStrength = 1.0e-3;
00129       fDeltaPruneStrength = (fPruneStrength - 1.0)/(Double_t)fQualityMap.size();
00130 
00131       return new PruningInfo(it->first, it->second, fPruneSequence);
00132       */
00133       return NULL;
00134    }
00135    else { // no automatic pruning - just use the provided prune strength parameter
00136       FindListOfNodes( (DecisionTreeNode*)dt->GetRoot() );
00137       return new PruningInfo( -1.0, fPruneStrength, fPruneSequence );
00138    }
00139 }
00140 
00141 //_______________________________________________________________________
00142 void TMVA::ExpectedErrorPruneTool::FindListOfNodes( DecisionTreeNode* node ) 
00143 {
00144    // recursive pruning of nodes using the Expected Error Pruning (EEP)
00145    TMVA::DecisionTreeNode *l = (TMVA::DecisionTreeNode*)node->GetLeft();
00146    TMVA::DecisionTreeNode *r = (TMVA::DecisionTreeNode*)node->GetRight();
00147    if (node->GetNodeType() == 0 && !(node->IsTerminal())) { // check all internal nodes
00148       this->FindListOfNodes(l);
00149       this->FindListOfNodes(r);
00150       if (this->GetSubTreeError(node) >= this->GetNodeError(node)) {
00151          //node->Print(Log());
00152          fPruneSequence.push_back(node);
00153       }
00154    }
00155 }
00156 
00157 //_______________________________________________________________________
00158 Double_t TMVA::ExpectedErrorPruneTool::GetSubTreeError( DecisionTreeNode* node ) const 
00159 {
00160    // calculate the expected statistical error on the subtree below "node"
00161    // which is used in the expected error pruning
00162    DecisionTreeNode *l = (DecisionTreeNode*)node->GetLeft();
00163    DecisionTreeNode *r = (DecisionTreeNode*)node->GetRight();
00164    if (node->GetNodeType() == 0 && !(node->IsTerminal())) {
00165       Double_t subTreeError =
00166          (l->GetNEvents() * this->GetSubTreeError(l) +
00167           r->GetNEvents() * this->GetSubTreeError(r)) /
00168          node->GetNEvents();
00169       return subTreeError;
00170    }
00171    else {
00172       return this->GetNodeError(node);
00173    }
00174 }
00175 
00176 //_______________________________________________________________________
00177 Double_t TMVA::ExpectedErrorPruneTool::GetNodeError( DecisionTreeNode *node ) const 
00178 {
00179    // Calculate an UPPER limit on the error made by the classification done
00180    // by this node. If the S/S+B of the node is f, then according to the
00181    // training sample, the error rate (fraction of misclassified events by
00182    // this node) is (1-f)
00183    // Now f has a statistical error according to the binomial distribution
00184    // hence the error on f can be estimated (same error as the binomial error
00185    // for efficency calculations ( sigma = sqrt(eff(1-eff)/nEvts ) )
00186 
00187    Double_t errorRate = 0;
00188 
00189    Double_t nEvts = node->GetNEvents();
00190 
00191    // fraction of correctly classified events by this node:
00192    Double_t f = 0;
00193    if (node->GetPurity() > fNodePurityLimit) f = node->GetPurity();
00194    else  f = (1-node->GetPurity());
00195 
00196    Double_t df = TMath::Sqrt(f*(1-f)/nEvts);
00197 
00198    errorRate = std::min(1.0,(1.0 - (f-fPruneStrength*df)));
00199 
00200    // -------------------------------------------------------------------
00201    // standard algorithm:
00202    // step 1: Estimate error on node using Laplace estimate
00203    //         NodeError = (N - n + k -1 ) / (N + k)
00204    //   N: number of events
00205    //   k: number of event classes (2 for Signal, Background)
00206    //   n: n event out of N belong to the class which has the majority in the node
00207    // step 2: Approximate "backed-up" error assuming we did not prune
00208    //   (I'm never quite sure if they consider whole subtrees, or only 'next-to-leaf'
00209    //    nodes)...
00210    //   Subtree error = Sum_children ( P_i * NodeError_i)
00211    //    P_i = probability of the node to make the decision, i.e. fraction of events in
00212    //          leaf node ( N_leaf / N_parent)
00213    // step 3:
00214 
00215    // Minimum Error Pruning (MEP) accordig to Niblett/Bratko
00216    //# of correctly classified events by this node:
00217    //Double_t n=f*nEvts ;
00218    //Double_t p_apriori = 0.5, m=100;
00219    //errorRate = (nEvts  - n + (1-p_apriori) * m ) / (nEvts  + m);
00220 
00221    // Pessimistic error Pruing (proposed by Quinlan (error estimat with continuity approximation)
00222    //# of correctly classified events by this node:
00223    //Double_t n=f*nEvts ;
00224    //errorRate = (nEvts  - n + 0.5) / nEvts ;
00225 
00226    //const Double Z=.65;
00227    //# of correctly classified events by this node:
00228    //Double_t n=f*nEvts ;
00229    //errorRate = (f + Z*Z/(2*nEvts ) + Z*sqrt(f/nEvts  - f*f/nEvts  + Z*Z/4/nEvts /nEvts ) ) / (1 + Z*Z/nEvts );
00230    //errorRate = (n + Z*Z/2 + Z*sqrt(n - n*n/nEvts  + Z*Z/4) )/ (nEvts  + Z*Z);
00231    //errorRate = 1 - errorRate;
00232    // -------------------------------------------------------------------
00233 
00234    return errorRate;
00235 }
00236 
00237 

Generated on Tue Jul 5 15:16:47 2011 for ROOT_528-00b_version by  doxygen 1.5.1