00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027 #include "TMVA/CostComplexityPruneTool.h"
00028
00029 #include "TMVA/MsgLogger.h"
00030
00031 #include <fstream>
00032 #include <limits>
00033 #include <math.h>
00034
00035 using namespace TMVA;
00036
00037
00038
00039 CostComplexityPruneTool::CostComplexityPruneTool( SeparationBase* qualityIndex ) :
00040 IPruneTool(),
00041 fLogger(new MsgLogger("CostComplexityPruneTool") )
00042 {
00043
00044
00045 fOptimalK = -1;
00046
00047
00048
00049
00050
00051
00052
00053 fQualityIndexTool = qualityIndex;
00054
00055
00056 fLogger->SetMinType( kWARNING );
00057 }
00058
00059
00060 CostComplexityPruneTool::~CostComplexityPruneTool( ) {
00061
00062 if(fQualityIndexTool != NULL) delete fQualityIndexTool;
00063 }
00064
00065
00066 PruningInfo*
00067 CostComplexityPruneTool::CalculatePruningInfo( DecisionTree* dt,
00068 const IPruneTool::EventSample* validationSample,
00069 Bool_t isAutomatic )
00070 {
00071
00072
00073
00074
00075 if( isAutomatic ) SetAutomatic();
00076
00077 if( dt == NULL || (IsAutomatic() && validationSample == NULL) ) {
00078
00079
00080
00081 return NULL;
00082 }
00083
00084 Double_t Q = -1.0;
00085 Double_t W = 1.0;
00086
00087 if(IsAutomatic()) {
00088
00089 dt->ApplyValidationSample(validationSample);
00090 W = dt->GetSumWeights(validationSample);
00091
00092 Q = dt->TestPrunedTreeQuality();
00093
00094 Log() << kDEBUG << "Node purity limit is: " << dt->GetNodePurityLimit() << Endl;
00095 Log() << kDEBUG << "Sum of weights in pruning validation sample: " << W << Endl;
00096 Log() << kDEBUG << "Quality of tree prior to any pruning is " << Q/W << Endl;
00097 }
00098
00099
00100 try {
00101 InitTreePruningMetaData((DecisionTreeNode*)dt->GetRoot());
00102 }
00103 catch(std::string error) {
00104 Log() << kERROR << "Couldn't initialize the tree meta data because of error ("
00105 << error << ")" << Endl;
00106 return NULL;
00107 }
00108
00109 Log() << kDEBUG << "Automatic cost complexity pruning is " << (IsAutomatic()?"on":"off") << "." << Endl;
00110
00111 try {
00112 Optimize( dt, W );
00113 }
00114 catch(std::string error) {
00115 Log() << kERROR << "Error optimzing pruning sequence ("
00116 << error << ")" << Endl;
00117 return NULL;
00118 }
00119
00120 Log() << kDEBUG << "Index of pruning sequence to stop at: " << fOptimalK << Endl;
00121
00122 PruningInfo* info = new PruningInfo();
00123
00124
00125 if(fOptimalK < 0) {
00126
00127 info->PruneStrength = 0;
00128 info->QualityIndex = Q/W;
00129 info->PruneSequence.clear();
00130 Log() << kINFO << "no proper pruning could be calulated. Tree "
00131 << dt->GetTreeID() << " will not be pruned. Do not worry if this "
00132 << " happens for a few trees " << Endl;
00133 return info;
00134 }
00135 info->QualityIndex = fQualityIndexList[fOptimalK]/W;
00136 Log() << kDEBUG << " prune until k=" << fOptimalK << " with alpha="<<fPruneStrengthList[fOptimalK]<< Endl;
00137 for( Int_t i = 0; i < fOptimalK; i++ ){
00138 info->PruneSequence.push_back(fPruneSequence[i]);
00139 }
00140 if( IsAutomatic() ){
00141 info->PruneStrength = fPruneStrengthList[fOptimalK];
00142 }
00143 else {
00144 info->PruneStrength = fPruneStrength;
00145 }
00146
00147 return info;
00148 }
00149
00150
00151 void CostComplexityPruneTool::InitTreePruningMetaData( DecisionTreeNode* n ) {
00152
00153
00154
00155 if( n == NULL ) return;
00156
00157 Double_t s = n->GetNSigEvents();
00158 Double_t b = n->GetNBkgEvents();
00159
00160 if (fQualityIndexTool) n->SetNodeR( (s+b)*fQualityIndexTool->GetSeparationIndex(s,b));
00161 else n->SetNodeR( (s+b)*n->GetSeparationIndex() );
00162
00163 if(n->GetLeft() != NULL && n->GetRight() != NULL) {
00164 n->SetTerminal(kFALSE);
00165
00166 InitTreePruningMetaData(n->GetLeft());
00167 InitTreePruningMetaData(n->GetRight());
00168
00169 n->SetNTerminal( n->GetLeft()->GetNTerminal() +
00170 n->GetRight()->GetNTerminal());
00171
00172 n->SetSubTreeR( (n->GetLeft()->GetSubTreeR() +
00173 n->GetRight()->GetSubTreeR()));
00174
00175 n->SetAlpha( ((n->GetNodeR() - n->GetSubTreeR()) /
00176 (n->GetNTerminal() - 1)));
00177
00178
00179
00180 n->SetAlphaMinSubtree( std::min(n->GetAlpha(), std::min(n->GetLeft()->GetAlphaMinSubtree(),
00181 n->GetRight()->GetAlphaMinSubtree())));
00182 n->SetCC(n->GetAlpha());
00183
00184 } else {
00185 n->SetNTerminal( 1 ); n->SetTerminal( );
00186 if (fQualityIndexTool) n->SetSubTreeR(((s+b)*fQualityIndexTool->GetSeparationIndex(s,b)));
00187 else n->SetSubTreeR( (s+b)*n->GetSeparationIndex() );
00188 n->SetAlpha(std::numeric_limits<double>::infinity( ));
00189 n->SetAlphaMinSubtree(std::numeric_limits<double>::infinity( ));
00190 n->SetCC(n->GetAlpha());
00191 }
00192
00193
00194
00195
00196 }
00197
00198
00199
00200 void CostComplexityPruneTool::Optimize( DecisionTree* dt, Double_t weights ) {
00201
00202
00203
00204
00205
00206
00207
00208
00209 Int_t k = 1;
00210 Double_t alpha = -1.0e10;
00211 Double_t epsilon = std::numeric_limits<double>::epsilon();
00212
00213 fQualityIndexList.clear();
00214 fPruneSequence.clear();
00215 fPruneStrengthList.clear();
00216
00217 DecisionTreeNode* R = (DecisionTreeNode*)dt->GetRoot();
00218
00219 Double_t qmin = 0.0;
00220 if(IsAutomatic()){
00221
00222 qmin = dt->TestPrunedTreeQuality()/weights;
00223 }
00224
00225
00226
00227
00228
00229
00230
00231
00232
00233
00234 while(R->GetNTerminal() > 1) {
00235
00236
00237 alpha = TMath::Max(R->GetAlphaMinSubtree(), alpha);
00238
00239 if( R->GetAlphaMinSubtree() >= R->GetAlpha() ) {
00240 Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
00241 break;
00242 }
00243
00244
00245 DecisionTreeNode* t = R;
00246
00247
00248 while(t->GetAlphaMinSubtree() < t->GetAlpha()) {
00249
00250
00251
00252
00253 if(TMath::Abs(t->GetAlphaMinSubtree() - t->GetLeft()->GetAlphaMinSubtree()) < epsilon) {
00254 t = t->GetLeft();
00255 } else {
00256 t = t->GetRight();
00257 }
00258 }
00259
00260 if( t == R ) {
00261 Log() << kDEBUG << "\nCaught trying to prune the root node!" << Endl;
00262 break;
00263 }
00264
00265 DecisionTreeNode* n = t;
00266
00267
00268
00269
00270
00271
00272
00273
00274 dt->PruneNodeInPlace(t);
00275
00276 while(t != R) {
00277 t = t->GetParent();
00278 t->SetNTerminal(t->GetLeft()->GetNTerminal() + t->GetRight()->GetNTerminal());
00279 t->SetSubTreeR(t->GetLeft()->GetSubTreeR() + t->GetRight()->GetSubTreeR());
00280 t->SetAlpha((t->GetNodeR() - t->GetSubTreeR())/(t->GetNTerminal() - 1));
00281 t->SetAlphaMinSubtree(std::min(t->GetAlpha(), std::min(t->GetLeft()->GetAlphaMinSubtree(),
00282 t->GetRight()->GetAlphaMinSubtree())));
00283 t->SetCC(t->GetAlpha());
00284 }
00285 k += 1;
00286
00287 Log() << kDEBUG << "after this pruning step I would have " << R->GetNTerminal() << " remaining terminal nodes " << Endl;
00288
00289 if(IsAutomatic()) {
00290 Double_t q = dt->TestPrunedTreeQuality()/weights;
00291 fQualityIndexList.push_back(q);
00292 }
00293 else {
00294 fQualityIndexList.push_back(1.0);
00295 }
00296 fPruneSequence.push_back(n);
00297 fPruneStrengthList.push_back(alpha);
00298 }
00299
00300 if(fPruneSequence.size() == 0) {
00301 fOptimalK = -1;
00302 return;
00303 }
00304
00305 if(IsAutomatic()) {
00306 k = -1;
00307 for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
00308 if(fQualityIndexList[i] < qmin) {
00309 qmin = fQualityIndexList[i];
00310 k = i;
00311 }
00312 }
00313 fOptimalK = k;
00314 }
00315 else {
00316
00317 fOptimalK = int(fPruneStrength/100.0 * fPruneSequence.size() );
00318 Log() << kDEBUG << "SequenzeSize="<<fPruneSequence.size()
00319 << " fOptimalK " << fOptimalK << Endl;
00320
00321 }
00322
00323 Log() << kDEBUG << "\n************ Summary for Tree " << dt->GetTreeID() << " *******" << Endl
00324 << "Number of trees in the sequence: " << fPruneSequence.size() << Endl;
00325
00326 Log() << kDEBUG << "Pruning strength parameters: [";
00327 for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
00328 Log() << kDEBUG << fPruneStrengthList[i] << ", ";
00329 Log() << kDEBUG << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << Endl;
00330
00331 Log() << kDEBUG << "Misclassification rates: [";
00332 for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
00333 Log() << kDEBUG << fQualityIndexList[i] << ", ";
00334 Log() << kDEBUG << fQualityIndexList[fQualityIndexList.size()-1] << "]" << Endl;
00335
00336 Log() << kDEBUG << "Prune index: " << fOptimalK+1 << Endl;
00337
00338 }
00339