/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis      *
 * Package: TMVA                                                                  *
 * Class  : TMVA::ExpectedErrorPruneTool                                          *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation of the expected error pruning algorithm for decision      *
 *      trees                                                                     *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
 *      Doug Schouten   <dschoute@sfu.ca>        - Simon Fraser U., Canada        *
 *                                                                                *
 * Copyright (c) 2005:                                                            *
 *      CERN, Switzerland                                                         *
 *      U. of Victoria, Canada                                                    *
 *      MPI-K Heidelberg, Germany                                                 *
 *                                                                                *
 * Redistribution and use in source and binary forms, with or without            *
 * modification, are permitted according to the terms listed in LICENSE          *
 * (http://tmva.sourceforge.net/license.txt)                                      *
 *                                                                                *
 **********************************************************************************/

#include "TMVA/ExpectedErrorPruneTool.h"
#include "TMVA/MsgLogger.h"

#include "TMath.h"

#include <algorithm>
#include <map>

//_______________________________________________________________________
TMVA::ExpectedErrorPruneTool::ExpectedErrorPruneTool() :
   IPruneTool(),
   fDeltaPruneStrength(0),
   fNodePurityLimit(1),
   fLogger( new MsgLogger("ExpectedErrorPruneTool") )
{}

//_______________________________________________________________________
TMVA::ExpectedErrorPruneTool::~ExpectedErrorPruneTool()
{
   // destructor: delete the owned message logger
   delete fLogger;
}

//_______________________________________________________________________
TMVA::PruningInfo*
TMVA::ExpectedErrorPruneTool::CalculatePruningInfo( DecisionTree* dt,
                                                    const IPruneTool::EventSample* validationSample,
                                                    Bool_t isAutomatic )
{
   if( isAutomatic ) {
      //SetAutomatic( );
      isAutomatic = kFALSE;
      Log() << kWARNING << "Sorry, automatic pruning strength determination is not implemented yet" << Endl;
   }
   if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
      // must have a valid decision tree to prune, and if the prune strength
      // is to be chosen automatically, must have a test sample from
      // which to calculate the quality of the pruned tree(s)
      return NULL;
   }
   fNodePurityLimit = dt->GetNodePurityLimit();

   if(IsAutomatic()) {
      Log() << kFATAL << "Sorry, automatic pruning strength determination is not implemented yet" << Endl;
      /*
      dt->ApplyValidationSample(validationSample);
      Double_t weights = dt->GetSumWeights(validationSample);
      // set the initial prune strength
      fPruneStrength = 1.0e-3;
      //! FIXME: somehow make this automatic; it depends strongly on the tree structure.
      // Better to set it too small: it will be increased automatically.
      fDeltaPruneStrength = 1.0e-5;
      Int_t nnodes = this->CountNodes((DecisionTreeNode*)dt->GetRoot());

      Bool_t forceStop = kFALSE;
      Int_t errCount = 0,
            lastNodeCount = nnodes;

      // find the maximum prune strength that still leaves the root's daughter nodes
      while ( nnodes > 1 && !forceStop ) {
         fPruneStrength += fDeltaPruneStrength;
         Log() << "----------------------------------------------------" << Endl;
         FindListOfNodes((DecisionTreeNode*)dt->GetRoot());
         for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
            fPruneSequence[i]->SetTerminal(); // prune all the nodes from the sequence
         // test the quality of the pruned tree
         Double_t quality = 1.0 - dt->TestPrunedTreeQuality()/weights;
         fQualityMap.insert(std::make_pair(quality, fPruneStrength));

         nnodes = CountNodes((DecisionTreeNode*)dt->GetRoot()); // count the number of nodes in the pruned tree

         Log() << "Prune strength : " << fPruneStrength << Endl;
         Log() << "Had " << lastNodeCount << " nodes, now have " << nnodes << Endl;
         Log() << "Quality index is: " << quality << Endl;

         if (lastNodeCount == nnodes) errCount++;
         else {
            errCount = 0; // reset counter
            if ( nnodes < lastNodeCount / 2 ) {
               Log() << "Decreasing fDeltaPruneStrength to " << fDeltaPruneStrength/2.0
                     << " because the number of nodes in the tree decreased by a factor of 2." << Endl;
               fDeltaPruneStrength /= 2.;
            }
         }
         lastNodeCount = nnodes;
         if (errCount > 20) {
            Log() << "Increasing fDeltaPruneStrength to " << fDeltaPruneStrength*2.0
                  << " because the number of nodes in the tree didn't change." << Endl;
            fDeltaPruneStrength *= 2.0;
         }
         if (errCount > 40) {
            Log() << "Having difficulty determining the optimal prune strength, bailing out!" << Endl;
            forceStop = kTRUE;
         }
         // reset the tree for the next iteration
         for( UInt_t i = 0; i < fPruneSequence.size(); i++ )
            fPruneSequence[i]->SetTerminal(false);
         fPruneSequence.clear();
      }
      // from the set of pruned trees, find the one with the optimal quality index;
      // rbegin() points at the entry with the highest quality, since the map is
      // sorted by quality in ascending order
      std::multimap<Double_t,Double_t>::reverse_iterator it = fQualityMap.rbegin();
      fPruneStrength = it->second;
      FindListOfNodes((DecisionTreeNode*)dt->GetRoot());

      // adjust the step size for the next tree automatically
      fPruneStrength = 1.0e-3;
      fDeltaPruneStrength = (fPruneStrength - 1.0)/(Double_t)fQualityMap.size();

      return new PruningInfo(it->first, it->second, fPruneSequence);
      */
      return NULL;
   }
   else { // no automatic pruning - just use the provided prune strength parameter
      FindListOfNodes( (DecisionTreeNode*)dt->GetRoot() );
      return new PruningInfo( -1.0, fPruneStrength, fPruneSequence );
   }
}
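
// -------------------------------------------------------------------
// Usage sketch (illustration only, not code from this package): a
// minimal example of how a caller might drive this tool with a fixed,
// user-chosen prune strength, assuming `dt` is a fully grown
// TMVA::DecisionTree and that the IPruneTool::SetPruneStrength setter
// is available. Since automatic strength determination is not
// implemented, no validation sample is passed.
//
//    TMVA::ExpectedErrorPruneTool pruneTool;
//    pruneTool.SetPruneStrength( 4.0 ); // fixed strength, chosen by the user
//    TMVA::PruningInfo* info = pruneTool.CalculatePruningInfo( dt, NULL, kFALSE );
//    // info carries the prune strength and the sequence of internal
//    // nodes that should be turned into leaf nodes; the caller owns it
//    delete info;
// -------------------------------------------------------------------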
//_______________________________________________________________________
void TMVA::ExpectedErrorPruneTool::FindListOfNodes( DecisionTreeNode* node )
{
   // recursively traverse the tree and collect in fPruneSequence the nodes
   // to be pruned, according to Expected Error Pruning (EEP)
   TMVA::DecisionTreeNode *l = (TMVA::DecisionTreeNode*)node->GetLeft();
   TMVA::DecisionTreeNode *r = (TMVA::DecisionTreeNode*)node->GetRight();
   if (node->GetNodeType() == 0 && !(node->IsTerminal())) { // check all internal nodes
      this->FindListOfNodes(l);
      this->FindListOfNodes(r);
      // prune the node if keeping its subtree is not expected to reduce the error
      if (this->GetSubTreeError(node) >= this->GetNodeError(node)) {
         //node->Print(Log());
         fPruneSequence.push_back(node);
      }
   }
}

//_______________________________________________________________________
Double_t TMVA::ExpectedErrorPruneTool::GetSubTreeError( DecisionTreeNode* node ) const
{
   // calculate the expected statistical error on the subtree below "node",
   // which is used in the expected error pruning: the event-weighted average
   // of the expected errors of the daughter subtrees
   DecisionTreeNode *l = (DecisionTreeNode*)node->GetLeft();
   DecisionTreeNode *r = (DecisionTreeNode*)node->GetRight();
   if (node->GetNodeType() == 0 && !(node->IsTerminal())) {
      Double_t subTreeError =
         (l->GetNEvents() * this->GetSubTreeError(l) +
          r->GetNEvents() * this->GetSubTreeError(r)) /
         node->GetNEvents();
      return subTreeError;
   }
   else {
      return this->GetNodeError(node);
   }
}
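
// -------------------------------------------------------------------
// Worked example for GetSubTreeError (illustrative numbers, not taken
// from real data): an internal node with 100 events whose left leaf
// holds 60 events with node error 0.10 and whose right leaf holds 40
// events with node error 0.30 gets the event-weighted average
//    (60 * 0.10 + 40 * 0.30) / 100 = 0.18
// as its subtree error. FindListOfNodes then schedules the node for
// pruning whenever this subtree error is not smaller than the node's
// own expected error from GetNodeError.
// -------------------------------------------------------------------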
//_______________________________________________________________________
Double_t TMVA::ExpectedErrorPruneTool::GetNodeError( DecisionTreeNode *node ) const
{
   // Calculate an UPPER limit on the error made by the classification done
   // by this node. If the purity S/(S+B) of the node is f, then according to
   // the training sample the error rate (fraction of events misclassified by
   // this node) is (1-f). Now f has a statistical error according to the
   // binomial distribution, hence the error on f can be estimated (it is the
   // same error as the binomial error for efficiency calculations:
   // sigma = sqrt( eff(1-eff)/nEvts ) ).

   Double_t errorRate = 0;

   Double_t nEvts = node->GetNEvents();

   // fraction of correctly classified events by this node:
   Double_t f = 0;
   if (node->GetPurity() > fNodePurityLimit) f = node->GetPurity();
   else                                      f = (1 - node->GetPurity());

   Double_t df = TMath::Sqrt(f*(1-f)/nEvts);

   errorRate = std::min(1.0, (1.0 - (f - fPruneStrength*df)));

   // -------------------------------------------------------------------
   // standard algorithm:
   // step 1: Estimate the error on the node using the Laplace estimate
   //         NodeError = (N - n + k - 1) / (N + k)
   //         N: number of events
   //         k: number of event classes (2 for signal, background)
   //         n: number of the N events that belong to the majority class
   //            of the node
   // step 2: Approximate the "backed-up" error assuming we did not prune
   //         (it is not always clear in the literature whether whole
   //         subtrees or only 'next-to-leaf' nodes are considered)...
   //         Subtree error = Sum_children ( P_i * NodeError_i )
   //         P_i = probability for the node to make the decision, i.e. the
   //               fraction of events in the leaf node ( N_leaf / N_parent )
   // step 3:

   // Minimum Error Pruning (MEP) according to Niblett/Bratko
   //# of correctly classified events by this node:
   //Double_t n = f*nEvts;
   //Double_t p_apriori = 0.5, m = 100;
   //errorRate = (nEvts - n + (1 - p_apriori) * m) / (nEvts + m);

   // Pessimistic Error Pruning (proposed by Quinlan; error estimate with continuity approximation)
   //# of correctly classified events by this node:
   //Double_t n = f*nEvts;
   //errorRate = (nEvts - n + 0.5) / nEvts;

   //const Double_t Z = 0.65;
   //# of correctly classified events by this node:
   //Double_t n = f*nEvts;
   //errorRate = (f + Z*Z/(2*nEvts) + Z*sqrt(f/nEvts - f*f/nEvts + Z*Z/4/nEvts/nEvts)) / (1 + Z*Z/nEvts);
   //errorRate = (n + Z*Z/2 + Z*sqrt(n - n*n/nEvts + Z*Z/4)) / (nEvts + Z*Z);
   //errorRate = 1 - errorRate;
   // -------------------------------------------------------------------

   return errorRate;
}
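
// -------------------------------------------------------------------
// Worked example for GetNodeError (illustrative numbers, not taken
// from real data): a node with nEvts = 100 and purity 0.90 (above the
// purity limit) has f = 0.90 and
//    df = sqrt( 0.90 * 0.10 / 100 ) = 0.03 ,
// so with fPruneStrength = 1 the upper error limit is
//    1 - (0.90 - 1 * 0.03) = 0.13 ,
// i.e. three percentage points above the naive training error rate of
// 1 - f = 0.10. Larger prune strengths inflate these limits, and since
// the statistically smaller leaf nodes carry the larger relative
// errors df, they make the pruning more aggressive.
// -------------------------------------------------------------------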