CCPruner.cxx

Go to the documentation of this file.
00001 /**********************************************************************************
00002  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00003  * Package: TMVA                                                                  *
00004  * Class  : CCPruner                                                              *
00005  * Web    : http://tmva.sourceforge.net                                           *
00006  *                                                                                *
00007  * Description: Cost Complexity Pruning                                           *
00008  * 
00009  * Author: Doug Schouten (dschoute@sfu.ca)
00010  *
00011  *                                                                                *
00012  * Copyright (c) 2007:                                                            *
00013  *      CERN, Switzerland                                                         *
00014  *      MPI-K Heidelberg, Germany                                                 *
00015  *      U. of Texas at Austin, USA                                                *
00016  *                                                                                *
00017  * Redistribution and use in source and binary forms, with or without             *
00018  * modification, are permitted according to the terms listed in LICENSE           *
00019  * (http://tmva.sourceforge.net/LICENSE)                                          *
00020  **********************************************************************************/
00021 
00022 #include "TMVA/CCPruner.h"
00023 #include "TMVA/SeparationBase.h"
00024 #include "TMVA/GiniIndex.h"
00025 #include "TMVA/MisClassificationError.h"
00026 #include "TMVA/CCTreeWrapper.h"
00027 
00028 #include <iostream>
00029 #include <fstream>
00030 #include <limits>
00031 #include <math.h>
00032 
00033  using namespace TMVA;
00034 
00035 //_______________________________________________________________________
00036 CCPruner::CCPruner( DecisionTree* t_max, const EventList* validationSample,
00037                     SeparationBase* qualityIndex ) : 
00038    fAlpha(-1.0), 
00039    fValidationSample(validationSample),
00040    fValidationDataSet(NULL),
00041    fOptimalK(-1)
00042 {
00043    // constructor
00044    fTree = t_max;
00045    
00046    if(qualityIndex == NULL) {
00047       fOwnQIndex = true;
00048       fQualityIndex = new MisClassificationError();
00049    }
00050    else {
00051       fOwnQIndex = false;
00052       fQualityIndex = qualityIndex;
00053    }
00054    fDebug = kTRUE;
00055 }
00056 
00057 //_______________________________________________________________________
00058 CCPruner::CCPruner( DecisionTree* t_max, const DataSet* validationSample,
00059                     SeparationBase* qualityIndex ) : 
00060    fAlpha(-1.0), 
00061    fValidationSample(NULL),
00062    fValidationDataSet(validationSample),
00063    fOptimalK(-1)
00064 {
00065    // constructor
00066    fTree = t_max;
00067    
00068    if(qualityIndex == NULL) {
00069       fOwnQIndex = true;
00070       fQualityIndex = new MisClassificationError();
00071    }
00072    else {
00073       fOwnQIndex = false;
00074       fQualityIndex = qualityIndex;
00075    }
00076    fDebug = kTRUE;
00077 }
00078 
00079 
00080 //_______________________________________________________________________
00081 CCPruner::~CCPruner( )
00082 {
00083    if(fOwnQIndex) delete fQualityIndex;
00084    // destructor
00085 }
00086 
00087 //_______________________________________________________________________
00088 void CCPruner::Optimize( )
00089 {
00090    // determine the pruning sequence
00091 
00092    Bool_t HaveStopCondition = fAlpha > 0; // keep pruning the tree until reach the limit fAlpha
00093 
00094    // build a wrapper tree to perform work on
00095    CCTreeWrapper* dTWrapper = new CCTreeWrapper(fTree, fQualityIndex);
00096 
00097    Int_t    k = 0;
00098    Double_t epsilon = std::numeric_limits<double>::epsilon();
00099    Double_t alpha = -1.0e10;
00100 
00101    ofstream outfile;
00102    if (fDebug) outfile.open("costcomplexity.log");
00103    if(!HaveStopCondition && (fValidationSample == NULL && fValidationDataSet == NULL) ) {
00104       if (fDebug) outfile << "ERROR: no validation sample, so cannot optimize pruning!" << std::endl;
00105       delete dTWrapper;
00106       if (fDebug) outfile.close();
00107       return;
00108    }
00109 
00110    CCTreeWrapper::CCTreeNode* R = dTWrapper->GetRoot();
00111    while(R->GetNLeafDaughters() > 1) { // prune upwards to the root node
00112       if(R->GetMinAlphaC() > alpha) 
00113          alpha = R->GetMinAlphaC(); // initialize alpha
00114 
00115       if(HaveStopCondition && alpha > fAlpha) break;
00116 
00117       CCTreeWrapper::CCTreeNode* t = R;
00118 
00119       while(t->GetMinAlphaC() < t->GetAlphaC()) { // descend to the weakest link
00120 
00121          if(fabs(t->GetMinAlphaC() - t->GetLeftDaughter()->GetMinAlphaC())/fabs(t->GetMinAlphaC()) < epsilon) 
00122             t = t->GetLeftDaughter();
00123          else
00124             t = t->GetRightDaughter();
00125       }
00126     
00127       if( t == R ) {
00128          if (fDebug) outfile << std::endl << "Caught trying to prune the root node!" << std::endl;
00129          break;
00130       }
00131 
00132       CCTreeWrapper::CCTreeNode* n = t;
00133 
00134       if (fDebug){
00135          outfile << "===========================" << std::endl
00136                  << "Pruning branch listed below" << std::endl
00137                  << "===========================" << std::endl;
00138          t->PrintRec( outfile );
00139        
00140       }
00141       if (!(t->GetLeftDaughter()) && !(t->GetRightDaughter()) ) {
00142          break;
00143       }
00144       dTWrapper->PruneNode(t); // prune the branch rooted at node t
00145 
00146       while(t != R) { // go back up the (pruned) tree and recalculate R(T), alpha_c
00147          t = t->GetMother();
00148          t->SetNLeafDaughters(t->GetLeftDaughter()->GetNLeafDaughters() + t->GetRightDaughter()->GetNLeafDaughters());
00149          t->SetResubstitutionEstimate(t->GetLeftDaughter()->GetResubstitutionEstimate() + 
00150                                       t->GetRightDaughter()->GetResubstitutionEstimate());
00151          t->SetAlphaC((t->GetNodeResubstitutionEstimate() - t->GetResubstitutionEstimate())/(t->GetNLeafDaughters() - 1));
00152          t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(), 
00153                                                            t->GetRightDaughter()->GetMinAlphaC())));
00154       }
00155       k += 1;
00156       if(!HaveStopCondition) {
00157          Double_t q;
00158          if (fValidationDataSet != NULL) q = dTWrapper->TestTreeQuality(fValidationDataSet);
00159          else q = dTWrapper->TestTreeQuality(fValidationSample);
00160          fQualityIndexList.push_back(q);
00161       }
00162       else { 
00163          fQualityIndexList.push_back(1.0);
00164       }
00165       fPruneSequence.push_back(n->GetDTNode());
00166       fPruneStrengthList.push_back(alpha);
00167    }
00168   
00169    Double_t qmax = -1.0e6;
00170    if(!HaveStopCondition) {
00171       for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
00172          if(fQualityIndexList[i] > qmax) {
00173             qmax = fQualityIndexList[i];
00174             k = i;
00175          }
00176       }
00177       fOptimalK = k;
00178    }
00179    else {
00180       fOptimalK = fPruneSequence.size() - 1;
00181    }
00182 
00183    if (fDebug){
00184       outfile << std::endl << "************ Summary **************"  << std::endl
00185               << "Number of trees in the sequence: " << fPruneSequence.size() << std::endl;
00186      
00187       outfile << "Pruning strength parameters: [";
00188       for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++) 
00189          outfile << fPruneStrengthList[i] << ", ";
00190       outfile << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << std::endl;
00191      
00192       outfile << "Misclassification rates: [";
00193       for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++) 
00194          outfile << fQualityIndexList[i] << ", ";
00195       outfile << fQualityIndexList[fQualityIndexList.size()-1] << "]"  << std::endl;
00196      
00197       outfile << "Optimal index: " << fOptimalK+1 << std::endl;
00198       outfile.close();
00199    }
00200    delete dTWrapper;
00201 }
00202 
00203 //_______________________________________________________________________
00204 std::vector<DecisionTreeNode*> CCPruner::GetOptimalPruneSequence( ) const
00205 {
00206    // return the prune strength (=alpha) corresponding to the prune sequence
00207    std::vector<DecisionTreeNode*> optimalSequence;
00208    if( fOptimalK >= 0 ) {
00209       for( Int_t i = 0; i < fOptimalK; i++ ) {
00210          optimalSequence.push_back(fPruneSequence[i]);
00211       }
00212    }
00213    return optimalSequence;
00214 }
00215 
00216 

Generated on Tue Jul 5 15:16:41 2011 for ROOT_528-00b_version by  doxygen 1.5.1