CostComplexityPruneTool.cxx

Go to the documentation of this file.
00001 /**********************************************************************************
00002  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00003  * Package: TMVA                                                                  *
00004  * Class  : TMVA::CostComplexityPruneTool                                         *
00005  * Web    : http://tmva.sourceforge.net                                           *
00006  *                                                                                *
00007  * Description:                                                                   *
00008  *      Implementation of the cost complexity pruning of a decision tree          *
00009  *                                                                                *
00010  * Authors (alphabetical):                                                        *
00011  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00012  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00013  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
00014  *      Doug Schouten   <dschoute@sfu.ca>        - Simon Fraser U., Canada        *
00015  *                                                                                *
00016  * Copyright (c) 2005:                                                            *
00017  *      CERN, Switzerland                                                         *
00018  *      U. of Victoria, Canada                                                    *
00019  *      MPI-K Heidelberg, Germany                                                 *
00020  *                                                                                *
00021  * Redistribution and use in source and binary forms, with or without             *
00022  * modification, are permitted according to the terms listed in LICENSE           *
00023  * (http://mva.sourceforge.net/license.txt)                                       *
00024  *                                                                                *
00025  **********************************************************************************/
00026 
00027 #include "TMVA/CostComplexityPruneTool.h"
00028 
00029 #include "TMVA/MsgLogger.h"
00030 
00031 #include <fstream>
00032 #include <limits>
00033 #include <math.h>
00034 
00035 using namespace TMVA;
00036 
00037 
00038 //_______________________________________________________________________
00039 CostComplexityPruneTool::CostComplexityPruneTool( SeparationBase* qualityIndex ) : 
00040    IPruneTool(),
00041    fLogger(new MsgLogger("CostComplexityPruneTool") )
00042 {
00043    // the constructor for the cost complexity prunig
00044 
00045    fOptimalK = -1;
00046 
00047    // !! changed from Dougs code. Now use the QualityIndex stored already
00048    // in the nodes when no "new" QualityIndex calculator is given. Like this
00049    // I can easily implement the Regression. For Regression, the pruning uses the
00050    // same sepearation index as in the tree building, hence doesn't need to re-calculate
00051    // (which would need more info than simply "s" and "b")
00052    
00053    fQualityIndexTool = qualityIndex;
00054 
00055    //fLogger->SetMinType( kDEBUG );
00056    fLogger->SetMinType( kWARNING );
00057 }
00058 
00059 //_______________________________________________________________________
00060 CostComplexityPruneTool::~CostComplexityPruneTool( ) {
00061    // the destructor for the cost complexity prunig
00062    if(fQualityIndexTool != NULL) delete fQualityIndexTool;
00063 }
00064 
00065 //_______________________________________________________________________
00066 PruningInfo*
00067 CostComplexityPruneTool::CalculatePruningInfo( DecisionTree* dt,
00068                                                const IPruneTool::EventSample* validationSample,
00069                                                Bool_t isAutomatic )
00070 {
00071 
00072    // the routine that basically "steers" the pruning process. Call the calculation of
00073    // the pruning sequence, the tree quality and alike..
00074    
00075    if( isAutomatic ) SetAutomatic();
00076 
00077    if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
00078       // must have a valid decision tree to prune, and if the prune strength
00079       // is to be chosen automatically, must have a test sample from
00080       // which to calculate the quality of the pruned tree(s)
00081       return NULL;
00082    }
00083 
00084    Double_t Q = -1.0;
00085    Double_t W = 1.0;
00086 
00087    if(IsAutomatic()) {
00088       // run the pruning validation sample through the unpruned tree
00089       dt->ApplyValidationSample(validationSample);
00090       W = dt->GetSumWeights(validationSample); // get the sum of weights in the pruning validation sample
00091       // calculate the quality of the tree in the unpruned case
00092       Q = dt->TestPrunedTreeQuality();
00093 
00094       Log() << kDEBUG << "Node purity limit is: " << dt->GetNodePurityLimit() << Endl;
00095       Log() << kDEBUG << "Sum of weights in pruning validation sample: " << W << Endl;
00096       Log() << kDEBUG << "Quality of tree prior to any pruning is " << Q/W << Endl;
00097    }
00098 
00099    // store the cost complexity metadata for the decision tree at each node
00100    try {
00101       InitTreePruningMetaData((DecisionTreeNode*)dt->GetRoot());
00102    }
00103    catch(std::string error) {
00104       Log() << kERROR << "Couldn't initialize the tree meta data because of error ("
00105               << error << ")" << Endl;
00106       return NULL;
00107    }
00108 
00109    Log() << kDEBUG << "Automatic cost complexity pruning is " << (IsAutomatic()?"on":"off") << "." << Endl;
00110 
00111    try {
00112       Optimize( dt, W );  // run the cost complexity pruning algorithm
00113    }
00114    catch(std::string error) {
00115       Log() << kERROR << "Error optimzing pruning sequence ("
00116               << error << ")" << Endl;
00117       return NULL;
00118    }
00119 
00120    Log() << kDEBUG << "Index of pruning sequence to stop at: " << fOptimalK << Endl;
00121 
00122    PruningInfo* info = new PruningInfo();
00123 
00124 
00125    if(fOptimalK < 0) {
00126       // no pruning necessary, or wasn't able to compute a sequence
00127       info->PruneStrength = 0;
00128       info->QualityIndex = Q/W;
00129       info->PruneSequence.clear();
00130       Log() << kINFO << "no proper pruning could be calulated. Tree "   
00131             <<  dt->GetTreeID() << " will not be pruned. Do not worry if this " 
00132             << " happens for a few trees " << Endl;
00133       return info;
00134    }
00135    info->QualityIndex = fQualityIndexList[fOptimalK]/W;
00136    Log() << kDEBUG << " prune until k=" << fOptimalK << " with alpha="<<fPruneStrengthList[fOptimalK]<< Endl;
00137    for( Int_t i = 0; i < fOptimalK; i++ ){
00138       info->PruneSequence.push_back(fPruneSequence[i]);
00139    }
00140    if( IsAutomatic() ){
00141       info->PruneStrength = fPruneStrengthList[fOptimalK];
00142    }
00143    else {
00144       info->PruneStrength = fPruneStrength;
00145    }
00146 
00147    return info;
00148 }
00149 
00150 //_______________________________________________________________________
00151 void CostComplexityPruneTool::InitTreePruningMetaData( DecisionTreeNode* n ) {
00152    // initialise "meta data" for the pruning, like the "costcomplexity", the
00153    // critical alpha, the minimal alpha down the tree, etc...  for each node!!
00154 
00155    if( n == NULL ) return;
00156 
00157    Double_t s = n->GetNSigEvents();
00158    Double_t b = n->GetNBkgEvents();
00159    // set R(t) = N_events*Gini(t) or MisclassificationError(t), etc.
00160    if (fQualityIndexTool) n->SetNodeR( (s+b)*fQualityIndexTool->GetSeparationIndex(s,b));
00161    else n->SetNodeR( (s+b)*n->GetSeparationIndex() );
00162 
00163    if(n->GetLeft() != NULL && n->GetRight() != NULL) { // n is an interior (non-leaf) node
00164       n->SetTerminal(kFALSE);
00165       // traverse the tree
00166       InitTreePruningMetaData(n->GetLeft());
00167       InitTreePruningMetaData(n->GetRight());
00168       // set |~T_t|
00169       n->SetNTerminal( n->GetLeft()->GetNTerminal() +
00170                        n->GetRight()->GetNTerminal());
00171       // set R(T) = sum[n' in ~T]{ R(n') }
00172       n->SetSubTreeR( (n->GetLeft()->GetSubTreeR() +
00173                        n->GetRight()->GetSubTreeR()));
00174       // set alpha_c, the alpha value at which it becomes advantageaus to prune at node n
00175       n->SetAlpha( ((n->GetNodeR() - n->GetSubTreeR()) /
00176                     (n->GetNTerminal() - 1)));
00177 
00178       // G(t) = min( alpha_c, G(l(n)), G(r(n)) )
00179       // the minimum alpha in subtree rooted at this node
00180       n->SetAlphaMinSubtree( std::min(n->GetAlpha(), std::min(n->GetLeft()->GetAlphaMinSubtree(),
00181                                                               n->GetRight()->GetAlphaMinSubtree())));
00182       n->SetCC(n->GetAlpha());
00183 
00184    } else { // n is a terminal node
00185       n->SetNTerminal( 1 ); n->SetTerminal( );
00186       if (fQualityIndexTool) n->SetSubTreeR(((s+b)*fQualityIndexTool->GetSeparationIndex(s,b)));
00187       else n->SetSubTreeR( (s+b)*n->GetSeparationIndex() );
00188       n->SetAlpha(std::numeric_limits<double>::infinity( ));
00189       n->SetAlphaMinSubtree(std::numeric_limits<double>::infinity( ));
00190       n->SetCC(n->GetAlpha());
00191    }
00192 
00193 //    DecisionTreeNode* R = (DecisionTreeNode*)mdt->GetRoot();
00194 //    Double_t x = R->GetAlphaMinSubtree();
00195 //    Log() << "alphaMin(Root) = " << x << Endl;
00196 }
00197 
00198 
00199 //_______________________________________________________________________
00200 void CostComplexityPruneTool::Optimize( DecisionTree* dt, Double_t weights ) {
00201    // after the critical alpha values (at which the corresponding nodes would
00202    // be pruned away) had been established in the "InitMetaData" we need now:
00203    // automatic pruning:
00204    //   find the value of "alpha" for which the test sample gives minimal error,
00205    //   on the tree with all nodes pruned that have alpha_critital < alpha,
00206    // fixed parameter pruning
00207    //
00208 
00209    Int_t k = 1;
00210    Double_t alpha   = -1.0e10;
00211    Double_t epsilon = std::numeric_limits<double>::epsilon();
00212 
00213    fQualityIndexList.clear();
00214    fPruneSequence.clear();
00215    fPruneStrengthList.clear();
00216 
00217    DecisionTreeNode* R = (DecisionTreeNode*)dt->GetRoot();
00218 
00219    Double_t qmin = 0.0;
00220    if(IsAutomatic()){
00221       // initialize the tree quality (actually at this stage, it is the quality of the yet unpruned tree
00222       qmin = dt->TestPrunedTreeQuality()/weights;
00223    }
00224 
00225    // now prune the tree in steps until it is gone. At each pruning step, the pruning 
00226    // takes place at the node that is regarded as the "weakest link".
00227    // for automatic pruning, at each step, we calculate the current quality of the 
00228    //     tree and in the end we will prune at the minimum of the tree quality   
00229    // for the fixed parameter pruing, the cut is simply set at a relative position
00230    //     in the sequence according to the "lenght" of the sequence of pruned trees.
00231    //     100: at the end (pruned until the root node would be the next pruning candidate
00232    //     50: in the middle of the sequence
00233    //     etc...
00234    while(R->GetNTerminal() > 1) { // prune upwards to the root node
00235 
00236       // initialize alpha
00237       alpha = TMath::Max(R->GetAlphaMinSubtree(), alpha);
00238 
00239       if( R->GetAlphaMinSubtree() >= R->GetAlpha() ) {
00240          Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
00241          break;
00242       }
00243 
00244 
00245       DecisionTreeNode* t = R;
00246 
00247       // descend to the weakest link
00248       while(t->GetAlphaMinSubtree() < t->GetAlpha()) {
00249 //          std::cout << t->GetAlphaMinSubtree() << "  " << t->GetAlpha()<< "  "
00250 //                    << t->GetAlphaMinSubtree()- t->GetAlpha()<<  " t==R?" << int(t == R) << std::endl;
00251          //      while(  (t->GetAlphaMinSubtree() - t->GetAlpha()) < epsilon)  {
00252          //         if(TMath::Abs(t->GetAlphaMinSubtree() - t->GetLeft()->GetAlphaMinSubtree())/TMath::Abs(t->GetAlphaMinSubtree()) < epsilon) {
00253          if(TMath::Abs(t->GetAlphaMinSubtree() - t->GetLeft()->GetAlphaMinSubtree()) < epsilon) {
00254             t = t->GetLeft();
00255          } else {
00256             t = t->GetRight();
00257          }
00258       }
00259 
00260       if( t == R ) {
00261          Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
00262          break;
00263       }
00264 
00265       DecisionTreeNode* n = t;
00266 
00267 //       Log() << kDEBUG  << "alpha[" << k << "]: " << alpha << Endl;
00268 //       Log() << kDEBUG  << "===========================" << Endl
00269 //               << "Pruning branch listed below the node" << Endl;
00270 //       t->Print( Log() );
00271 //       Log() << kDEBUG << "===========================" << Endl;
00272 //       t->PrintRecPrune( Log() );
00273 
00274       dt->PruneNodeInPlace(t); // prune the branch rooted at node t
00275 
00276       while(t != R) { // go back up the (pruned) tree and recalculate R(T), alpha_c
00277          t = t->GetParent();
00278          t->SetNTerminal(t->GetLeft()->GetNTerminal() + t->GetRight()->GetNTerminal());
00279          t->SetSubTreeR(t->GetLeft()->GetSubTreeR() + t->GetRight()->GetSubTreeR());
00280          t->SetAlpha((t->GetNodeR() - t->GetSubTreeR())/(t->GetNTerminal() - 1));
00281          t->SetAlphaMinSubtree(std::min(t->GetAlpha(), std::min(t->GetLeft()->GetAlphaMinSubtree(),
00282                                                                 t->GetRight()->GetAlphaMinSubtree())));
00283          t->SetCC(t->GetAlpha());
00284       }
00285       k += 1;
00286    
00287       Log() << kDEBUG << "after this pruning step I would have " << R->GetNTerminal() << " remaining terminal nodes " << Endl;
00288 
00289       if(IsAutomatic()) {
00290          Double_t q = dt->TestPrunedTreeQuality()/weights;
00291          fQualityIndexList.push_back(q);
00292       }
00293       else {
00294          fQualityIndexList.push_back(1.0);
00295       }
00296       fPruneSequence.push_back(n);
00297       fPruneStrengthList.push_back(alpha);
00298    }
00299 
00300    if(fPruneSequence.size() == 0) {
00301       fOptimalK = -1;
00302       return;
00303    }
00304 
00305    if(IsAutomatic()) {
00306       k = -1;
00307       for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
00308          if(fQualityIndexList[i] < qmin) {
00309             qmin = fQualityIndexList[i];
00310             k = i;
00311          }
00312       }
00313       fOptimalK = k;
00314    }
00315    else {
00316       // regularize the prune strength relative to this tree
00317       fOptimalK = int(fPruneStrength/100.0 * fPruneSequence.size() );
00318       Log() << kDEBUG << "SequenzeSize="<<fPruneSequence.size()
00319             << "  fOptimalK " << fOptimalK << Endl;
00320 
00321    }
00322 
00323    Log() << kDEBUG  << "\n************ Summary for Tree " << dt->GetTreeID() << " *******"  << Endl
00324          << "Number of trees in the sequence: " << fPruneSequence.size() << Endl;
00325 
00326    Log() << kDEBUG  << "Pruning strength parameters: [";
00327    for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
00328       Log() << kDEBUG << fPruneStrengthList[i] << ", ";
00329    Log() << kDEBUG << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << Endl;
00330 
00331    Log() << kDEBUG  << "Misclassification rates: [";
00332    for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
00333       Log() << kDEBUG  << fQualityIndexList[i] << ", ";
00334    Log() << kDEBUG  << fQualityIndexList[fQualityIndexList.size()-1] << "]"  << Endl;
00335 
00336    Log() << kDEBUG  << "Prune index: " << fOptimalK+1 << Endl;
00337 
00338 }
00339 

Generated on Tue Jul 5 15:16:42 2011 for ROOT_528-00b_version by  doxygen 1.5.1