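// TMVA::MethodDT
//
// A single decision tree as a TMVA classification method ("DT"): one
// TMVA::DecisionTree is grown on the training sample. The method is mainly
// intended as the base classifier of TMVA::MethodBoost, which drives the
// boosting and calls back into MonitorBoost() for per-tree monitoring and
// optional pruning.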
#include <algorithm>
#include "Riostream.h"
#include "TRandom3.h"
#include "TMath.h"
#include "TObjString.h"
#include "TH1.h"   // TH1I/TH1D monitoring histograms booked in MonitorBoost()

#include "TMVA/ClassifierFactory.h"
#include "TMVA/MethodDT.h"
#include "TMVA/Tools.h"
#include "TMVA/Timer.h"
#include "TMVA/Ranking.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/BinarySearchTree.h"
#include "TMVA/SeparationBase.h"
#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/MethodBoost.h"
#include "TMVA/CCPruner.h"

using std::vector;

REGISTER_METHOD(DT)

ClassImp(TMVA::MethodDT)
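//_______________________________________________________________________
// standard constructor: build the classifier from the job name, the
// data-set description and the user option string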
TMVA::MethodDT::MethodDT( const TString& jobName,
                          const TString& methodTitle,
                          DataSetInfo& theData,
                          const TString& theOption,
                          TDirectory* theTargetDir ) :
   TMVA::MethodBase( jobName, Types::kDT, methodTitle, theData, theOption, theTargetDir )
   , fTree(0)
   , fNodeMinEvents(0)
   , fNCuts(0)
   , fUseYesNoLeaf(kFALSE)
   , fNodePurityLimit(0)
   , fNNodesMax(0)
   , fMaxDepth(0)
   , fErrorFraction(0)
   , fPruneStrength(0)
   , fPruneMethod(DecisionTree::kNoPruning)
   , fAutomatic(kFALSE)
   , fRandomisedTrees(kFALSE)
   , fUseNvars(0)
   , fPruneBeforeBoost(kFALSE)
   , fDeltaPruneStrength(0)
{
}
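//_______________________________________________________________________
// constructor used when reading the classifier back from a weight file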
TMVA::MethodDT::MethodDT( DataSetInfo& dsi,
                          const TString& theWeightFile,
                          TDirectory* theTargetDir ) :
   TMVA::MethodBase( Types::kDT, dsi, theWeightFile, theTargetDir )
   , fTree(0)
   , fNodeMinEvents(0)
   , fNCuts(0)
   , fUseYesNoLeaf(kFALSE)
   , fNodePurityLimit(0)
   , fNNodesMax(0)
   , fMaxDepth(0)
   , fErrorFraction(0)
   , fPruneStrength(0)
   , fPruneMethod(DecisionTree::kNoPruning)
   , fAutomatic(kFALSE)
   , fRandomisedTrees(kFALSE)
   , fUseNvars(0)
   , fPruneBeforeBoost(kFALSE)
   , fDeltaPruneStrength(0)
{
}
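//_______________________________________________________________________
// DT can handle classification with two classes only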
Bool_t TMVA::MethodDT::HasAnalysisType( Types::EAnalysisType type, UInt_t numberClasses, UInt_t )
{
   if (type == Types::kClassification && numberClasses == 2) return kTRUE;
   return kFALSE;
}
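//_______________________________________________________________________
// define the options (their key words) that can be set in the option string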
void TMVA::MethodDT::DeclareOptions()
{
   // each DeclareOptionRef call below registers one option together with
   // its default value and a short description
   DeclareOptionRef(fRandomisedTrees,"UseRandomisedTrees","Choose at each node splitting a random subset of variables (use with *bagging* as boost method)");
   DeclareOptionRef(fUseNvars,"UseNvars","Number of variables used if randomised tree option is chosen");
   DeclareOptionRef(fUseYesNoLeaf=kTRUE, "UseYesNoLeaf",
                    "Use Sig or Bkg node type or the ratio S/B as classification in the leaf node");
   DeclareOptionRef(fNodePurityLimit=0.5, "NodePurityLimit", "In boosting/pruning, nodes with purity > NodePurityLimit are signal; background otherwise.");
   DeclareOptionRef(fPruneBeforeBoost=kFALSE, "PruneBeforeBoost",
                    "Whether to prune right after the training or only after the boosting");
   DeclareOptionRef(fSepTypeS="GiniIndex", "SeparationType", "Separation criterion for node splitting");
   AddPreDefVal(TString("MisClassificationError"));
   AddPreDefVal(TString("GiniIndex"));
   AddPreDefVal(TString("CrossEntropy"));
   AddPreDefVal(TString("SDivSqrtSPlusB"));
   DeclareOptionRef(fNodeMinEvents, "nEventsMin", "Minimum number of events in a leaf node (default: max(20, N_train/(10*Nvar^2)) )");
   DeclareOptionRef(fNCuts, "nCuts", "Number of steps during node cut optimisation");
   DeclareOptionRef(fPruneStrength, "PruneStrength", "Pruning strength (negative value == automatic adjustment)");
   DeclareOptionRef(fPruneMethodS, "PruneMethod", "Pruning method: NoPruning (switched off), ExpectedError or CostComplexity");
   AddPreDefVal(TString("NoPruning"));
   AddPreDefVal(TString("ExpectedError"));
   AddPreDefVal(TString("CostComplexity"));

   DeclareOptionRef(fNNodesMax=100000,"NNodesMax","Max number of nodes in tree");
   if (DoRegression()) {
      DeclareOptionRef(fMaxDepth=50,"MaxDepth","Max depth of the decision tree allowed");
   }
   else {
      DeclareOptionRef(fMaxDepth=3,"MaxDepth","Max depth of the decision tree allowed");
   }
}
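// Illustrative booking sketch only (not part of this file): assuming an
// old-style TMVA::Factory instance named "factory" (a hypothetical variable),
// a single decision tree with the options declared above could be booked
// roughly like
//
//    factory->BookMethod( TMVA::Types::kDT, "DT",
//                         "SeparationType=GiniIndex:nCuts=20:"
//                         "PruneMethod=CostComplexity:PruneStrength=-1" );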
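//_______________________________________________________________________
// decode the option strings set by the user and validate the configuration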
void TMVA::MethodDT::ProcessOptions()
{
   fSepTypeS.ToLower();
   if      (fSepTypeS == "misclassificationerror") fSepType = new MisClassificationError();
   else if (fSepTypeS == "giniindex")              fSepType = new GiniIndex();
   else if (fSepTypeS == "crossentropy")           fSepType = new CrossEntropy();
   else if (fSepTypeS == "sdivsqrtsplusb")         fSepType = new SdivSqrtSplusB();
   else {
      Log() << kINFO << GetOptions() << Endl;
      Log() << kFATAL << "<ProcessOptions> unknown separation index option requested" << Endl;
   }

   fPruneMethodS.ToLower();
   if      (fPruneMethodS == "expectederror" )  fPruneMethod = DecisionTree::kExpectedErrorPruning;
   else if (fPruneMethodS == "costcomplexity" ) fPruneMethod = DecisionTree::kCostComplexityPruning;
   else if (fPruneMethodS == "nopruning" )      fPruneMethod = DecisionTree::kNoPruning;
   else {
      Log() << kINFO << GetOptions() << Endl;
      Log() << kFATAL << "<ProcessOptions> unknown PruneMethod option requested" << Endl;
   }

   // a negative pruning strength requests automatic strength determination
   fAutomatic = (fPruneStrength < 0);
   if (fAutomatic && fPruneMethod == DecisionTree::kExpectedErrorPruning){
      Log() << kFATAL
            << "Sorry, automatic pruning strength determination is not implemented yet for ExpectedErrorPruning" << Endl;
   }

   if (this->Data()->HasNegativeEventWeights()){
      Log() << kINFO << "You are using a Monte Carlo sample that also contains negative event weights. "
            << "That should in principle be fine, as long as on average you end up with "
            << "something positive. For this to work, make sure that the minimal number "
            << "of (unweighted) events required in a tree node (currently nEventsMin="
            << fNodeMinEvents << ", settable via the option string when booking the "
            << "classifier) is large enough to allow for reasonable averaging! "
            << "If this does not help, you may want to try the option NoNegWeightsInTraining, "
            << "which ignores events with negative weight in the training." << Endl
            << Endl << "Note: you will get a WARNING message during the training if that should ever happen." << Endl;
   }

   if (fRandomisedTrees){
      Log() << kINFO << "Randomised trees should use *bagging* as the *boost* method. Did you set this in MethodBoost? Here, only *no pruning* can be enforced." << Endl;
      fPruneMethod = DecisionTree::kNoPruning;
   }
}
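//_______________________________________________________________________
// common initialisation with default values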
void TMVA::MethodDT::Init( void )
{
   // set the defaults; the option string processed later may override them
   fNodeMinEvents      = TMath::Max( 20, int( Data()->GetNTrainingEvents() / (10*GetNvar()*GetNvar())) );
   fNCuts              = 20;
   fPruneMethod        = DecisionTree::kNoPruning;
   fPruneStrength      = 5;    // a negative value would trigger automatic determination
   fDeltaPruneStrength = 0.1;
   fRandomisedTrees    = kFALSE;
   fUseNvars           = GetNvar();

   // set the reference cut on the MVA output that separates signal from background
   SetSignalReferenceCut( 0 );
   if (fAnalysisType == Types::kClassification || fAnalysisType == Types::kMulticlass ) {
      fMaxDepth = 3;
   }
   else {
      fMaxDepth = 50;
   }
}
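//_______________________________________________________________________
// destructor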
TMVA::MethodDT::~MethodDT( void )
{
   delete fTree;
}
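//_______________________________________________________________________
// grow the decision tree on the training sample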
void TMVA::MethodDT::Train( void )
{
   TMVA::DecisionTreeNode::fgIsTraining = true;
   fTree = new DecisionTree( fSepType, fNodeMinEvents, fNCuts, 0,
                             fRandomisedTrees, fUseNvars, fNNodesMax, fMaxDepth, 0 );
   if (fRandomisedTrees) Log() << kWARNING << "Randomised trees do not work properly yet in this framework,"
                               << " as there is no way to give each tree a new random seed; they"
                               << " will all be identical, which is not what you want." << Endl;
   fTree->SetAnalysisType( GetAnalysisType() );

   fTree->BuildTree(GetEventCollection(Types::kTraining));
   TMVA::DecisionTreeNode::fgIsTraining = false;
}
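//_______________________________________________________________________
// callback from MethodBoost: book the monitoring histograms at the start
// of the boosting, prune each tree (before or after the boosting, depending
// on PruneBeforeBoost) and print a summary at the end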
Bool_t TMVA::MethodDT::MonitorBoost( MethodBoost* booster )
{
   Int_t methodIndex = booster->GetMethodIndex();
   if (booster->GetBoostStage() == Types::kBoostProcBegin) {
      // book the monitoring histograms (one bin per boosted tree)
      booster->AddMonitoringHist(new TH1I("NodesBeforePruning","nodes before pruning",booster->GetBoostNum(),0,booster->GetBoostNum()));
      booster->AddMonitoringHist(new TH1I("NodesAfterPruning","nodes after pruning",booster->GetBoostNum(),0,booster->GetBoostNum()));
      booster->AddMonitoringHist(new TH1D("PruneStrength","prune strength",booster->GetBoostNum(),0,booster->GetBoostNum()));
   }

   if (booster->GetBoostStage() == Types::kBeforeTraining) {
      if (methodIndex == 0) {
         booster->GetMonitoringHist(2)->SetXTitle("#tree");
         booster->GetMonitoringHist(2)->SetYTitle("PruneStrength");

         if (fAutomatic) {
            // automatic pruning-strength determination needs an independent
            // validation sample: split off half of the training set
            Data()->DivideTrainingSet(2);
            Data()->MoveTrainingBlock(1,Types::kValidation,kTRUE);
         }
      }
   }
   else if (booster->GetBoostStage() == Types::kBeforeBoosting)
      booster->GetMonitoringHist(0)->SetBinContent(methodIndex+1,fTree->GetNNodes());

   // prune either right after the training (PruneBeforeBoost) or after the boosting
   if (booster->GetBoostStage() == ((fPruneBeforeBoost)?Types::kBeforeBoosting:Types::kBoostValidation)
       && !(fPruneMethod == DecisionTree::kNoPruning)) {

      if (methodIndex==0 && fPruneBeforeBoost == kFALSE)
         Log() << kINFO << "Pruning " << booster->GetBoostNum() << " decision trees ... patience please" << Endl;

      // reuse the pruning strength found for the previous tree as starting value
      if (fAutomatic && methodIndex > 0) {
         MethodDT* mdt = dynamic_cast<MethodDT*>(booster->GetPreviousMethod());
         if (mdt) fPruneStrength = mdt->GetPruneStrength();
      }

      booster->GetMonitoringHist(0)->SetBinContent(methodIndex+1,fTree->GetNNodes());
      booster->GetMonitoringHist(2)->SetBinContent(methodIndex+1,PruneTree(methodIndex));
      booster->GetMonitoringHist(1)->SetBinContent(methodIndex+1,fTree->GetNNodes());
   }
   else if (booster->GetBoostStage() != Types::kBoostProcEnd)
      return kFALSE;

   if (booster->GetBoostStage() == Types::kBoostProcEnd) {
      if (fPruneMethod == DecisionTree::kNoPruning) {
         Log() << kINFO << "<Train> average number of nodes (w/o pruning) : "
               << booster->GetMonitoringHist(0)->GetMean() << Endl;
      }
      else {
         Log() << kINFO << "<Train> average number of nodes before/after pruning : "
               << booster->GetMonitoringHist(0)->GetMean() << " / "
               << booster->GetMonitoringHist(1)->GetMean() << Endl;
      }
   }

   return kTRUE;
}
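//_______________________________________________________________________
// prune the tree according to the configured method; returns the pruning
// strength that was applied (the optimised one if PruneStrength < 0)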
Double_t TMVA::MethodDT::PruneTree(const Int_t methodIndex)
{
   if (fAutomatic && fPruneMethod == DecisionTree::kCostComplexityPruning) {
      // automatic cost-complexity pruning: let the CCPruner find the prune
      // sequence that is optimal on the validation sample
      CCPruner* pruneTool = new CCPruner(fTree, this->Data(), fSepType);
      pruneTool->Optimize();
      std::vector<DecisionTreeNode*> nodes = pruneTool->GetOptimalPruneSequence();
      fPruneStrength = pruneTool->GetOptimalPruneStrength();
      for (UInt_t i = 0; i < nodes.size(); i++)
         fTree->PruneNode(nodes[i]);
      delete pruneTool;
   }
   else if (fAutomatic && fPruneMethod != DecisionTree::kCostComplexityPruning){
      // automatic strength determination for ExpectedErrorPruning is not
      // implemented (ProcessOptions() aborts with kFATAL in that
      // configuration), so there is nothing to do here
      (void) methodIndex; // unused in this branch
   }
   else {
      // fixed, user-given pruning strength
      fTree->SetPruneStrength(fPruneStrength);
      fTree->PruneTree();
   }
   return fPruneStrength;
}
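//_______________________________________________________________________
// return the weighted fraction of correctly classified events of the
// given tree, evaluated on the validation sample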
Double_t TMVA::MethodDT::TestTreeQuality( DecisionTree *dt )
{
   Data()->SetCurrentType(Types::kValidation);

   // an event is classified as signal if the tree response exceeds the
   // node purity limit; compare that decision with the true class
   Double_t SumCorrect=0, SumWrong=0;
   for (Long64_t ievt=0; ievt<Data()->GetNEvents(); ievt++) {
      Event * ev = Data()->GetEvent(ievt);
      if ((dt->CheckEvent(*ev) > dt->GetNodePurityLimit()) == DataInfo().IsSignal(ev)) SumCorrect += ev->GetWeight();
      else                                                                             SumWrong   += ev->GetWeight();
   }
   Data()->SetCurrentType(Types::kTraining);
   return SumCorrect / (SumCorrect + SumWrong);
}
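//_______________________________________________________________________
// write the tree to the XML weight file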
void TMVA::MethodDT::AddWeightsXMLTo( void* parent ) const
{
   fTree->AddXMLTo(parent);
}
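//_______________________________________________________________________
// read the tree back from the XML weight file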
void TMVA::MethodDT::ReadWeightsFromXML( void* wghtnode)
{
   if (fTree)
      delete fTree;
   fTree = new DecisionTree();
   fTree->ReadXML(wghtnode,GetTrainingTMVAVersionCode());
}
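//_______________________________________________________________________
// read the tree from a plain-text weight stream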
void TMVA::MethodDT::ReadWeightsFromStream( istream& istr )
{
   delete fTree;
   fTree = new DecisionTree();
   fTree->Read(istr);
}
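//_______________________________________________________________________
// return the MVA value of the current event: the leaf purity S/(S+B), or
// the leaf node type (signal/background) if UseYesNoLeaf is set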
Double_t TMVA::MethodDT::GetMvaValue( Double_t* err, Double_t* errUpper )
{
   // no error estimate is available for this classifier
   NoErrorCalc(err, errUpper);

   return fTree->CheckEvent(*GetEvent(),fUseYesNoLeaf);
}
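//_______________________________________________________________________
// print a help message for this classifier (currently none is implemented)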
void TMVA::MethodDT::GetHelpMessage() const
{
}
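//_______________________________________________________________________
// variable ranking is not implemented for a single decision tree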
const TMVA::Ranking* TMVA::MethodDT::CreateRanking()
{
   return 0;
}