00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022 #include "TMVA/CCPruner.h"
00023 #include "TMVA/SeparationBase.h"
00024 #include "TMVA/GiniIndex.h"
00025 #include "TMVA/MisClassificationError.h"
00026 #include "TMVA/CCTreeWrapper.h"
00027
00028 #include <iostream>
00029 #include <fstream>
00030 #include <limits>
00031 #include <math.h>
00032
00033 using namespace TMVA;
00034
00035
00036 CCPruner::CCPruner( DecisionTree* t_max, const EventList* validationSample,
00037 SeparationBase* qualityIndex ) :
00038 fAlpha(-1.0),
00039 fValidationSample(validationSample),
00040 fValidationDataSet(NULL),
00041 fOptimalK(-1)
00042 {
00043
00044 fTree = t_max;
00045
00046 if(qualityIndex == NULL) {
00047 fOwnQIndex = true;
00048 fQualityIndex = new MisClassificationError();
00049 }
00050 else {
00051 fOwnQIndex = false;
00052 fQualityIndex = qualityIndex;
00053 }
00054 fDebug = kTRUE;
00055 }
00056
00057
00058 CCPruner::CCPruner( DecisionTree* t_max, const DataSet* validationSample,
00059 SeparationBase* qualityIndex ) :
00060 fAlpha(-1.0),
00061 fValidationSample(NULL),
00062 fValidationDataSet(validationSample),
00063 fOptimalK(-1)
00064 {
00065
00066 fTree = t_max;
00067
00068 if(qualityIndex == NULL) {
00069 fOwnQIndex = true;
00070 fQualityIndex = new MisClassificationError();
00071 }
00072 else {
00073 fOwnQIndex = false;
00074 fQualityIndex = qualityIndex;
00075 }
00076 fDebug = kTRUE;
00077 }
00078
00079
00080
00081 CCPruner::~CCPruner( )
00082 {
00083 if(fOwnQIndex) delete fQualityIndex;
00084
00085 }
00086
00087
00088 void CCPruner::Optimize( )
00089 {
00090
00091
00092 Bool_t HaveStopCondition = fAlpha > 0;
00093
00094
00095 CCTreeWrapper* dTWrapper = new CCTreeWrapper(fTree, fQualityIndex);
00096
00097 Int_t k = 0;
00098 Double_t epsilon = std::numeric_limits<double>::epsilon();
00099 Double_t alpha = -1.0e10;
00100
00101 ofstream outfile;
00102 if (fDebug) outfile.open("costcomplexity.log");
00103 if(!HaveStopCondition && (fValidationSample == NULL && fValidationDataSet == NULL) ) {
00104 if (fDebug) outfile << "ERROR: no validation sample, so cannot optimize pruning!" << std::endl;
00105 delete dTWrapper;
00106 if (fDebug) outfile.close();
00107 return;
00108 }
00109
00110 CCTreeWrapper::CCTreeNode* R = dTWrapper->GetRoot();
00111 while(R->GetNLeafDaughters() > 1) {
00112 if(R->GetMinAlphaC() > alpha)
00113 alpha = R->GetMinAlphaC();
00114
00115 if(HaveStopCondition && alpha > fAlpha) break;
00116
00117 CCTreeWrapper::CCTreeNode* t = R;
00118
00119 while(t->GetMinAlphaC() < t->GetAlphaC()) {
00120
00121 if(fabs(t->GetMinAlphaC() - t->GetLeftDaughter()->GetMinAlphaC())/fabs(t->GetMinAlphaC()) < epsilon)
00122 t = t->GetLeftDaughter();
00123 else
00124 t = t->GetRightDaughter();
00125 }
00126
00127 if( t == R ) {
00128 if (fDebug) outfile << std::endl << "Caught trying to prune the root node!" << std::endl;
00129 break;
00130 }
00131
00132 CCTreeWrapper::CCTreeNode* n = t;
00133
00134 if (fDebug){
00135 outfile << "===========================" << std::endl
00136 << "Pruning branch listed below" << std::endl
00137 << "===========================" << std::endl;
00138 t->PrintRec( outfile );
00139
00140 }
00141 if (!(t->GetLeftDaughter()) && !(t->GetRightDaughter()) ) {
00142 break;
00143 }
00144 dTWrapper->PruneNode(t);
00145
00146 while(t != R) {
00147 t = t->GetMother();
00148 t->SetNLeafDaughters(t->GetLeftDaughter()->GetNLeafDaughters() + t->GetRightDaughter()->GetNLeafDaughters());
00149 t->SetResubstitutionEstimate(t->GetLeftDaughter()->GetResubstitutionEstimate() +
00150 t->GetRightDaughter()->GetResubstitutionEstimate());
00151 t->SetAlphaC((t->GetNodeResubstitutionEstimate() - t->GetResubstitutionEstimate())/(t->GetNLeafDaughters() - 1));
00152 t->SetMinAlphaC(std::min(t->GetAlphaC(), std::min(t->GetLeftDaughter()->GetMinAlphaC(),
00153 t->GetRightDaughter()->GetMinAlphaC())));
00154 }
00155 k += 1;
00156 if(!HaveStopCondition) {
00157 Double_t q;
00158 if (fValidationDataSet != NULL) q = dTWrapper->TestTreeQuality(fValidationDataSet);
00159 else q = dTWrapper->TestTreeQuality(fValidationSample);
00160 fQualityIndexList.push_back(q);
00161 }
00162 else {
00163 fQualityIndexList.push_back(1.0);
00164 }
00165 fPruneSequence.push_back(n->GetDTNode());
00166 fPruneStrengthList.push_back(alpha);
00167 }
00168
00169 Double_t qmax = -1.0e6;
00170 if(!HaveStopCondition) {
00171 for(UInt_t i = 0; i < fQualityIndexList.size(); i++) {
00172 if(fQualityIndexList[i] > qmax) {
00173 qmax = fQualityIndexList[i];
00174 k = i;
00175 }
00176 }
00177 fOptimalK = k;
00178 }
00179 else {
00180 fOptimalK = fPruneSequence.size() - 1;
00181 }
00182
00183 if (fDebug){
00184 outfile << std::endl << "************ Summary **************" << std::endl
00185 << "Number of trees in the sequence: " << fPruneSequence.size() << std::endl;
00186
00187 outfile << "Pruning strength parameters: [";
00188 for(UInt_t i = 0; i < fPruneStrengthList.size()-1; i++)
00189 outfile << fPruneStrengthList[i] << ", ";
00190 outfile << fPruneStrengthList[fPruneStrengthList.size()-1] << "]" << std::endl;
00191
00192 outfile << "Misclassification rates: [";
00193 for(UInt_t i = 0; i < fQualityIndexList.size()-1; i++)
00194 outfile << fQualityIndexList[i] << ", ";
00195 outfile << fQualityIndexList[fQualityIndexList.size()-1] << "]" << std::endl;
00196
00197 outfile << "Optimal index: " << fOptimalK+1 << std::endl;
00198 outfile.close();
00199 }
00200 delete dTWrapper;
00201 }
00202
00203
00204 std::vector<DecisionTreeNode*> CCPruner::GetOptimalPruneSequence( ) const
00205 {
00206
00207 std::vector<DecisionTreeNode*> optimalSequence;
00208 if( fOptimalK >= 0 ) {
00209 for( Int_t i = 0; i < fOptimalK; i++ ) {
00210 optimalSequence.push_back(fPruneSequence[i]);
00211 }
00212 }
00213 return optimalSequence;
00214 }
00215
00216