#include <iostream>
#include <algorithm>
#include <vector>
#include <limits>
#include <fstream>
#include <cassert>
#include <cstdlib>
#include <cstring>

#include "TRandom3.h"
#include "TMath.h"
#include "TMatrix.h"

#include "TMVA/MsgLogger.h"
#include "TMVA/DecisionTree.h"
#include "TMVA/DecisionTreeNode.h"
#include "TMVA/BinarySearchTree.h"

#include "TMVA/Tools.h"

#include "TMVA/GiniIndex.h"
#include "TMVA/CrossEntropy.h"
#include "TMVA/MisClassificationError.h"
#include "TMVA/SdivSqrtSplusB.h"
#include "TMVA/Event.h"
#include "TMVA/BDTEventWrapper.h"
#include "TMVA/IPruneTool.h"
#include "TMVA/CostComplexityPruneTool.h"
#include "TMVA/ExpectedErrorPruneTool.h"

const Int_t TMVA::DecisionTree::fgRandomSeed = 0;

using std::vector;

ClassImp(TMVA::DecisionTree)

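//_______________________________________________________________________
// default constructor: creates an empty tree without a separation
// criterion; the members are initialised such that the tree must be
// configured (or filled via ReadXML) before it can be used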
TMVA::DecisionTree::DecisionTree():
   BinaryTree(),
   fNvars          (0),
   fNCuts          (-1),
   fUseFisherCuts  (kFALSE),
   fMinLinCorrForFisher (1),
   fUseExclusiveVars (kTRUE),
   fSepType        (NULL),
   fRegType        (NULL),
   fMinSize        (0),
   fMinSepGain     (0),
   fUseSearchTree  (kFALSE),
   fPruneStrength  (0),
   fPruneMethod    (kNoPruning),
   fNodePurityLimit(0.5),
   fRandomisedTree (kFALSE),
   fUseNvars       (0),
   fUsePoissonNvars(kFALSE),
   fMyTrandom      (NULL),
   fNNodesMax      (999999),
   fMaxDepth       (999999),
   fClass          (0),
   fTreeID         (0),
   fAnalysisType   (Types::kClassification)
{
}

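//_______________________________________________________________________
// constructor specifying the separation type, the minimal number of
// events in a node that is still subject to further splitting, the
// number of grid points scanned when searching for the optimal cut,
// and the options steering randomised trees, pruning and depth limits;
// if no separation type is given, the tree is set up for regression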
TMVA::DecisionTree::DecisionTree( TMVA::SeparationBase *sepType, Int_t minSize, Int_t nCuts, UInt_t cls,
                                  Bool_t randomisedTree, Int_t useNvars, Bool_t usePoissonNvars, UInt_t nNodesMax,
                                  UInt_t nMaxDepth, Int_t iSeed, Float_t purityLimit, Int_t treeID):
   BinaryTree(),
   fNvars          (0),
   fNCuts          (nCuts),
   fUseFisherCuts  (kFALSE),
   fMinLinCorrForFisher (1),
   fUseExclusiveVars (kTRUE),
   fSepType        (sepType),
   fRegType        (NULL),
   fMinSize        (minSize),
   fMinSepGain     (0),
   fUseSearchTree  (kFALSE),
   fPruneStrength  (0),
   fPruneMethod    (kNoPruning),
   fNodePurityLimit(purityLimit),
   fRandomisedTree (randomisedTree),
   fUseNvars       (useNvars),
   fUsePoissonNvars(usePoissonNvars),
   fMyTrandom      (new TRandom3(iSeed)),
   fNNodesMax      (nNodesMax),
   fMaxDepth       (nMaxDepth),
   fClass          (cls),
   fTreeID         (treeID)
{
   if (sepType == NULL) {
      // no separation criterion given: interpret the tree as a regression
      // tree, where the "separation" is the variance reduction of the targets
      fAnalysisType = Types::kRegression;
      fRegType = new RegressionVariance();
      if (nCuts <= 0) {
         fNCuts = 200;
         Log() << kWARNING << " You have chosen the training mode using optimal cuts (by setting NCuts < 0).\n"
               << " As this mode is not available for regression, NCuts is set to " << fNCuts
               << " and the grid of cut values is used instead."
               << Endl;
      }
   } else {
      fAnalysisType = Types::kClassification;
   }
}

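//_______________________________________________________________________
// copy constructor that creates a true, completely independent copy:
// the node content is copied and the daughter nodes are re-created
// rather than shared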
TMVA::DecisionTree::DecisionTree( const DecisionTree &d ):
   BinaryTree(),
   fNvars          (d.fNvars),
   fNCuts          (d.fNCuts),
   fUseFisherCuts  (d.fUseFisherCuts),
   fMinLinCorrForFisher (d.fMinLinCorrForFisher),
   fUseExclusiveVars (d.fUseExclusiveVars),
   fSepType        (d.fSepType),
   fRegType        (d.fRegType),
   fMinSize        (d.fMinSize),
   fMinSepGain     (d.fMinSepGain),
   fUseSearchTree  (d.fUseSearchTree),
   fPruneStrength  (d.fPruneStrength),
   fPruneMethod    (d.fPruneMethod),
   fNodePurityLimit(d.fNodePurityLimit),
   fRandomisedTree (d.fRandomisedTree),
   fUseNvars       (d.fUseNvars),
   fUsePoissonNvars(d.fUsePoissonNvars),
   fMyTrandom      (new TRandom3(fgRandomSeed)),
   fNNodesMax      (d.fNNodesMax),
   fMaxDepth       (d.fMaxDepth),
   fClass          (d.fClass),
   fTreeID         (d.fTreeID),
   fAnalysisType   (d.fAnalysisType)
{
   this->SetRoot( new TMVA::DecisionTreeNode ( *((DecisionTreeNode*)(d.GetRoot())) ) );
   this->SetParentTreeInNodes();
   fNNodes = d.fNNodes;
}

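//_______________________________________________________________________
// destructor: the tree nodes are destroyed in ~BinaryTree; here only the
// random number generator owned by this tree is freed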
TMVA::DecisionTree::~DecisionTree()
{
   if (fMyTrandom) delete fMyTrandom;
}

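//_______________________________________________________________________
// descend the tree below node n (default: root) and set the parent-tree
// pointer in every node; the total tree depth is updated on the way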
void TMVA::DecisionTree::SetParentTreeInNodes( Node *n )
{
   if (n == NULL) {
      n = this->GetRoot();
      if (n == NULL) {
         Log() << kFATAL << "SetParentTreeNodes: started with undefined ROOT node" << Endl;
         return;
      }
   }

   if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) != NULL) ) {
      Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   } else if ((this->GetLeftDaughter(n) != NULL) && (this->GetRightDaughter(n) == NULL) ) {
      Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }
   else {
      if (this->GetLeftDaughter(n) != NULL) {
         this->SetParentTreeInNodes( this->GetLeftDaughter(n) );
      }
      if (this->GetRightDaughter(n) != NULL) {
         this->SetParentTreeInNodes( this->GetRightDaughter(n) );
      }
   }
   n->SetParentTree(this);
   if (n->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(n->GetDepth());
   return;
}

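//_______________________________________________________________________
// re-create a decision tree from the information stored in an XML node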
TMVA::DecisionTree* TMVA::DecisionTree::CreateFromXML(void* node, UInt_t tmva_Version_Code ) {
   std::string type("");
   gTools().ReadAttr(node,"type", type);
   DecisionTree* dt = new DecisionTree();

   dt->ReadXML( node, tmva_Version_Code );
   return dt;
}

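//_______________________________________________________________________
// build the decision tree from the event sample: starting at the (root-)
// node, recursively split each node into two daughter nodes until a
// stopping criterion is met (minimal node size, maximal number of nodes,
// maximal depth, or no further separation gain); returns the number of
// nodes in the grown tree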
UInt_t TMVA::DecisionTree::BuildTree( const vector<TMVA::Event*> & eventSample,
                                      TMVA::DecisionTreeNode *node)
{
   Bool_t IsRootNode=kFALSE;
   if (node==NULL) {
      IsRootNode = kTRUE;
      node = new TMVA::DecisionTreeNode();
      fNNodes = 1;
      this->SetRoot(node);

      this->GetRoot()->SetPos('s');
      this->GetRoot()->SetDepth(0);
      this->GetRoot()->SetParentTree(this);
   }

   UInt_t nevents = eventSample.size();

   if (nevents > 0 ) {
      fNvars = eventSample[0]->GetNVariables();
      fVariableImportance.resize(fNvars);
   }
   else Log() << kFATAL << "<BuildTree> eventSample size == 0 " << Endl;

   Double_t s=0, b=0;
   Double_t suw=0, buw=0;
   Double_t target=0, target2=0;
   Float_t *xmin = new Float_t[fNvars];
   Float_t *xmax = new Float_t[fNvars];
   for (UInt_t ivar=0; ivar<fNvars; ivar++) {
      xmin[ivar]=xmax[ivar]=0;
   }
   for (UInt_t iev=0; iev<eventSample.size(); iev++) {
      const TMVA::Event* evt = eventSample[iev];
      const Double_t weight = evt->GetWeight();
      if (evt->GetClass() == fClass) {
         s += weight;
         suw += 1;
      }
      else {
         b += weight;
         buw += 1;
      }
      if ( DoRegression() ) {
         const Double_t tgt = evt->GetTarget(0);
         target += weight*tgt;
         target2+= weight*tgt*tgt;
      }

      for (UInt_t ivar=0; ivar<fNvars; ivar++) {
         const Double_t val = evt->GetValue(ivar);
         if (iev==0) xmin[ivar]=xmax[ivar]=val;
         if (val < xmin[ivar]) xmin[ivar]=val;
         if (val > xmax[ivar]) xmax[ivar]=val;
      }
   }

   if (s+b < 0) {
      Log() << kWARNING << " One of the Decision Tree nodes has a negative total number of signal or background events. "
            << "(Nsig=" << s << " Nbkg=" << b << ") Probably you use a Monte Carlo with negative weights. That should in principle "
            << "be fine as long as on average you end up with something positive. For this you have to make sure that the "
            << "minimum number of (unweighted) events demanded for a tree node (currently you use: nEventsMin=" << fMinSize
            << ", you can set this via the BDT option string when booking the classifier) is large enough to allow for "
            << "reasonable averaging!!!" << Endl
            << " If this does not help.. maybe you want to try the option: NoNegWeightsInTraining which ignores events "
            << "with negative weight in the training." << Endl;
      double nBkg=0.;
      for (UInt_t i=0; i<eventSample.size(); i++) {
         if (eventSample[i]->GetClass() != fClass) {
            nBkg += eventSample[i]->GetWeight();
            Log() << kINFO << "Event " << i << " has (original) weight: " << eventSample[i]->GetWeight()/eventSample[i]->GetBoostWeight()
                  << " boostWeight: " << eventSample[i]->GetBoostWeight() << Endl;
         }
      }
      Log() << kINFO << " that gives in total: " << nBkg << Endl;
   }

   node->SetNSigEvents(s);
   node->SetNBkgEvents(b);
   node->SetNSigEvents_unweighted(suw);
   node->SetNBkgEvents_unweighted(buw);
   node->SetPurity();
   if (node == this->GetRoot()) {
      node->SetNEvents(s+b);
      node->SetNEvents_unweighted(suw+buw);
   }
   for (UInt_t ivar=0; ivar<fNvars; ivar++) {
      node->SetSampleMin(ivar,xmin[ivar]);
      node->SetSampleMax(ivar,xmax[ivar]);
   }
   delete[] xmin;
   delete[] xmax;

   // the minimum number of events is demanded for BOTH daughter nodes;
   // hence, if the parent node does not hold at least twice that number,
   // splitting is not even attempted
   if (eventSample.size() >= 2*fMinSize && fNNodes < fNNodesMax && node->GetDepth() < fMaxDepth
       && ( ( s!=0 && b!=0 && !DoRegression()) || ( (s+b)!=0 && DoRegression()) ) ) {
      Double_t separationGain;
      if (fNCuts > 0){
         separationGain = this->TrainNodeFast(eventSample, node);
      } else {
         separationGain = this->TrainNodeFull(eventSample, node);
      }
      if (separationGain < std::numeric_limits<double>::epsilon()) {
         // no separation gain could be achieved: the node becomes a leaf
         if (DoRegression()) {
            node->SetSeparationIndex(fRegType->GetSeparationIndex(s+b,target,target2));
            node->SetResponse(target/(s+b));
            node->SetRMS(TMath::Sqrt(target2/(s+b) - target/(s+b)*target/(s+b)));
         }
         else {
            node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
         }
         if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
         else node->SetNodeType(-1);
         if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());

      } else {

         vector<TMVA::Event*> leftSample; leftSample.reserve(nevents);
         vector<TMVA::Event*> rightSample; rightSample.reserve(nevents);

         Double_t nRight=0, nLeft=0;

         for (UInt_t ie=0; ie< nevents ; ie++) {
            if (node->GoesRight(*eventSample[ie])) {
               rightSample.push_back(eventSample[ie]);
               nRight += eventSample[ie]->GetWeight();
            }
            else {
               leftSample.push_back(eventSample[ie]);
               nLeft += eventSample[ie]->GetWeight();
            }
         }

         // sanity check
         if (leftSample.size() == 0 || rightSample.size() == 0) {
            Log() << kFATAL << "<TrainNode> all events went to the same branch" << Endl
                  << "--- Hence new node == old node ... check" << Endl
                  << "--- left:" << leftSample.size()
                  << " right:" << rightSample.size() << Endl
                  << "--- this should never happen, please write a bug report to Helge.Voss@cern.ch"
                  << Endl;
         }

         // continue building the daughter nodes for the left and right event samples
         TMVA::DecisionTreeNode *rightNode = new TMVA::DecisionTreeNode(node,'r');
         fNNodes++;
         rightNode->SetNEvents(nRight);
         rightNode->SetNEvents_unweighted(rightSample.size());

         TMVA::DecisionTreeNode *leftNode = new TMVA::DecisionTreeNode(node,'l');
         fNNodes++;
         leftNode->SetNEvents(nLeft);
         leftNode->SetNEvents_unweighted(leftSample.size());

         node->SetNodeType(0);
         node->SetLeft(leftNode);
         node->SetRight(rightNode);

         this->BuildTree(rightSample, rightNode);
         this->BuildTree(leftSample,  leftNode );
      }
   }
   else {
      // the stopping criterion is met: this node becomes a leaf
      if (DoRegression()) {
         node->SetSeparationIndex(fRegType->GetSeparationIndex(s+b,target,target2));
         node->SetResponse(target/(s+b));
         node->SetRMS(TMath::Sqrt(target2/(s+b) - target/(s+b)*target/(s+b)));
      }
      else {
         node->SetSeparationIndex(fSepType->GetSeparationIndex(s,b));
      }

      if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
      else node->SetNodeType(-1);

      if (node->GetDepth() > this->GetTotalTreeDepth()) this->SetTotalTreeDepth(node->GetDepth());
   }

   return fNNodes;
}

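//_______________________________________________________________________
// fill the existing decision tree structure: pass each event in from the
// top node and record in each node it passes how it is classified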
void TMVA::DecisionTree::FillTree( vector<TMVA::Event*> & eventSample )
{
   for (UInt_t i=0; i<eventSample.size(); i++) {
      this->FillEvent(*(eventSample[i]),NULL);
   }
}

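//_______________________________________________________________________
// fill a single event into the decision tree node structure, updating the
// event counts and the separation index of each node it passes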
void TMVA::DecisionTree::FillEvent( TMVA::Event & event,
                                    TMVA::DecisionTreeNode *node )
{
   if (node == NULL) {
      node = this->GetRoot();
   }

   node->IncrementNEvents( event.GetWeight() );
   node->IncrementNEvents_unweighted( );

   if (event.GetClass() == fClass) {
      node->IncrementNSigEvents( event.GetWeight() );
      node->IncrementNSigEvents_unweighted( );
   }
   else {
      node->IncrementNBkgEvents( event.GetWeight() );
      node->IncrementNBkgEvents_unweighted( );
   }
   node->SetSeparationIndex(fSepType->GetSeparationIndex(node->GetNSigEvents(),
                                                         node->GetNBkgEvents()));

   if (node->GetNodeType() == 0) {
      if (node->GoesRight(event))
         this->FillEvent(event,dynamic_cast<TMVA::DecisionTreeNode*>(node->GetRight())) ;
      else
         this->FillEvent(event,dynamic_cast<TMVA::DecisionTreeNode*>(node->GetLeft())) ;
   }
}

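//_______________________________________________________________________
// clear the statistics in all tree nodes while keeping the tree structure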
void TMVA::DecisionTree::ClearTree()
{
   if (this->GetRoot()!=NULL) this->GetRoot()->ClearNodeAndAllDaughters();
}

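//_______________________________________________________________________
// remove those last splits that result in two leaf nodes of the same
// type (i.e. both classified as signal or both as background); returns
// the number of nodes in the cleaned tree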
UInt_t TMVA::DecisionTree::CleanTree( DecisionTreeNode *node )
{
   if (node==NULL) {
      node = this->GetRoot();
   }

   DecisionTreeNode *l = node->GetLeft();
   DecisionTreeNode *r = node->GetRight();

   if (node->GetNodeType() == 0) {
      this->CleanTree(l);
      this->CleanTree(r);
      if (l->GetNodeType() * r->GetNodeType() > 0) {
         this->PruneNode(node);
      }
   }

   return this->CountNodes();
}

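//_______________________________________________________________________
// prune (remove internal nodes of) the decision tree to avoid
// overtraining; the pruning algorithm is chosen via fPruneMethod and,
// when it determines the prune strength automatically, requires an
// independent validation sample; returns the prune strength used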
Double_t TMVA::DecisionTree::PruneTree( vector<Event*>* validationSample )
{
   IPruneTool* tool(NULL);
   PruningInfo* info(NULL);

   if( fPruneMethod == kNoPruning ) return 0.0;

   if      (fPruneMethod == kExpectedErrorPruning) {
      tool = new ExpectedErrorPruneTool();
   }
   else if (fPruneMethod == kCostComplexityPruning) {
      tool = new CostComplexityPruneTool();
   }
   else {
      Log() << kFATAL << "Selected pruning method not yet implemented "
            << Endl;
   }
   if(!tool) return 0.0;

   tool->SetPruneStrength(GetPruneStrength());
   if(tool->IsAutomatic()) {
      if(validationSample == NULL)
         Log() << kFATAL << "Cannot automate the pruning algorithm without an "
               << "independent validation sample!" << Endl;
      if(validationSample->size() == 0)
         Log() << kFATAL << "Cannot automate the pruning algorithm with "
               << "independent validation sample of ZERO events!" << Endl;
   }

   info = tool->CalculatePruningInfo(this,validationSample);
   if(!info) {
      delete tool;
      Log() << kFATAL << "Error pruning tree! Check prune.log for more information."
            << Endl;
   }
   Double_t pruneStrength = info->PruneStrength;

   // remove all nodes flagged by the pruning algorithm
   for (UInt_t i = 0; i < info->PruneSequence.size(); ++i) {
      PruneNode(info->PruneSequence[i]);
   }

   // update the number of nodes after the pruning
   this->CountNodes();

   delete tool;
   delete info;

   return pruneStrength;
}

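//_______________________________________________________________________
// run the validation sample through the (pruned) tree and fill in each
// node the validation counters (NSValidation, NBValidation) and, for
// regression, the target sums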
void TMVA::DecisionTree::ApplyValidationSample( const EventList* validationSample ) const
{
   GetRoot()->ResetValidationData();
   for (UInt_t ievt=0; ievt < validationSample->size(); ievt++) {
      CheckEventWithPrunedTree(*(*validationSample)[ievt]);
   }
}

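//_______________________________________________________________________
// return the quality of the pruned tree on the validation sample: the
// sum of squared errors for regression; for classification either the
// misclassified validation weight (mode 0) or a purity-weighted error
// estimate (mode 1)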
Double_t TMVA::DecisionTree::TestPrunedTreeQuality( const DecisionTreeNode* n, Int_t mode ) const
{
   if (n == NULL) {
      n = this->GetRoot();
      if (n == NULL) {
         Log() << kFATAL << "TestPrunedTreeQuality: started with undefined ROOT node" << Endl;
         return 0;
      }
   }

   if( n->GetLeft() != NULL && n->GetRight() != NULL && !n->IsTerminal() ) {
      return (TestPrunedTreeQuality( n->GetLeft(), mode ) +
              TestPrunedTreeQuality( n->GetRight(), mode ));
   }
   else { // terminal leaf (in a pruned subtree of T_max at least)
      if (DoRegression()) {
         Double_t sumw = n->GetNSValidation() + n->GetNBValidation();
         return n->GetSumTarget2() - 2*n->GetSumTarget()*n->GetResponse() + sumw*n->GetResponse()*n->GetResponse();
      }
      else {
         if (mode == 0) {
            // misclassified weight: whatever does not match the majority type
            if (n->GetPurity() > this->GetNodePurityLimit())
               return n->GetNBValidation();
            else
               return n->GetNSValidation();
         }
         else if ( mode == 1 ) {
            // weighted error estimate using the node purity
            return (n->GetPurity() * n->GetNBValidation() + (1.0 - n->GetPurity()) * n->GetNSValidation());
         }
         else {
            throw std::string("Unknown ValidationQualityMode");
         }
      }
   }
}

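//_______________________________________________________________________
// pass a single validation event down the pruned decision tree, adding
// its weight to the validation counters of every node it visits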
void TMVA::DecisionTree::CheckEventWithPrunedTree( const Event& e ) const
{
   DecisionTreeNode* current = this->GetRoot();
   if (current == NULL) {
      Log() << kFATAL << "CheckEventWithPrunedTree: started with undefined ROOT node" << Endl;
   }

   while(current != NULL) {
      if(e.GetClass() == fClass)
         current->SetNSValidation(current->GetNSValidation() + e.GetWeight());
      else
         current->SetNBValidation(current->GetNBValidation() + e.GetWeight());

      if (e.GetNTargets() > 0) {
         current->AddToSumTarget(e.GetWeight()*e.GetTarget(0));
         current->AddToSumTarget2(e.GetWeight()*e.GetTarget(0)*e.GetTarget(0));
      }

      if (current->GetRight() == NULL || current->GetLeft() == NULL) {
         current = NULL;
      }
      else {
         if (current->GoesRight(e))
            current = (TMVA::DecisionTreeNode*)current->GetRight();
         else
            current = (TMVA::DecisionTreeNode*)current->GetLeft();
      }
   }
}

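//_______________________________________________________________________
// calculate the normalisation factor for a pruning validation sample,
// i.e. the sum of the event weights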
Double_t TMVA::DecisionTree::GetSumWeights( const EventList* validationSample ) const
{
   Double_t sumWeights = 0.0;
   for( EventList::const_iterator it = validationSample->begin();
        it != validationSample->end(); ++it ) {
      sumWeights += (*it)->GetWeight();
   }
   return sumWeights;
}

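//_______________________________________________________________________
// return the number of terminal nodes in the sub-tree below node n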
UInt_t TMVA::DecisionTree::CountLeafNodes( TMVA::Node *n )
{
   if (n == NULL) {
      n = this->GetRoot();
      if (n == NULL) {
         Log() << kFATAL << "CountLeafNodes: started with undefined ROOT node" << Endl;
         return 0;
      }
   }

   UInt_t countLeafs=0;

   if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) == NULL) ) {
      countLeafs += 1;
   }
   else {
      if (this->GetLeftDaughter(n) != NULL) {
         countLeafs += this->CountLeafNodes( this->GetLeftDaughter(n) );
      }
      if (this->GetRightDaughter(n) != NULL) {
         countLeafs += this->CountLeafNodes( this->GetRightDaughter(n) );
      }
   }
   return countLeafs;
}

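//_______________________________________________________________________
// descend the tree below node n (default: root), checking on the way
// that every intermediate node has both daughters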
void TMVA::DecisionTree::DescendTree( Node* n )
{
   if (n == NULL) {
      n = this->GetRoot();
      if (n == NULL) {
         Log() << kFATAL << "DescendTree: started with undefined ROOT node" << Endl;
         return;
      }
   }

   if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) == NULL) ) {
      // leaf node: nothing to do
   }
   else if ((this->GetLeftDaughter(n) == NULL) && (this->GetRightDaughter(n) != NULL) ) {
      Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }
   else if ((this->GetLeftDaughter(n) != NULL) && (this->GetRightDaughter(n) == NULL) ) {
      Log() << kFATAL << " Node with only one daughter?? Something went wrong" << Endl;
      return;
   }
   else {
      if (this->GetLeftDaughter(n) != NULL) {
         this->DescendTree( this->GetLeftDaughter(n) );
      }
      if (this->GetRightDaughter(n) != NULL) {
         this->DescendTree( this->GetRightDaughter(n) );
      }
   }
}

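//_______________________________________________________________________
// prune away the subtree below the node: delete both daughters and turn
// the node into a leaf whose type is determined by its purity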
void TMVA::DecisionTree::PruneNode( DecisionTreeNode* node )
{
   DecisionTreeNode *l = node->GetLeft();
   DecisionTreeNode *r = node->GetRight();

   node->SetRight(NULL);
   node->SetLeft(NULL);
   node->SetSelector(-1);
   node->SetSeparationGain(-1);
   if (node->GetPurity() > fNodePurityLimit) node->SetNodeType(1);
   else node->SetNodeType(-1);
   this->DeleteNode(l);
   this->DeleteNode(r);

   // update the stored number of nodes in the tree
   this->CountNodes();
}

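//_______________________________________________________________________
// prune a node "in place": mark it as terminal without deleting its
// descendants, so that the decision can be revisited while a pruning
// sequence is validated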
void TMVA::DecisionTree::PruneNodeInPlace( DecisionTreeNode* node ) {
   if(node == NULL) return;
   node->SetNTerminal(1);
   node->SetSubTreeR( node->GetNodeR() );
   node->SetAlpha( std::numeric_limits<double>::infinity( ) );
   node->SetAlphaMinSubtree( std::numeric_limits<double>::infinity( ) );
   node->SetTerminal(kTRUE);
}

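//_______________________________________________________________________
// retrieve a node from the tree by its position, encoded as a sequence
// of left/right moves starting at the root: bit i of "sequence" decides
// at depth i whether to descend right (1) or left (0)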
TMVA::Node* TMVA::DecisionTree::GetNode( ULong_t sequence, UInt_t depth )
{
   Node* current = this->GetRoot();

   for (UInt_t i = 0; i < depth; i++) {
      ULong_t tmp = 1UL << i; // unsigned long shift, so that depths beyond 31 work
      if ( tmp & sequence) current = this->GetRightDaughter(current);
      else current = this->GetLeftDaughter(current);
   }

   return current;
}

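//_______________________________________________________________________
// randomly select the subset of variables considered for the node split
// (used for randomised trees): the number of variables is fUseNvars, or
// Poisson-distributed around it if fUsePoissonNvars is set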
void TMVA::DecisionTree::GetRandomisedVariables(Bool_t *useVariable, UInt_t *mapVariable, UInt_t &useNvars){
   for (UInt_t ivar=0; ivar<fNvars; ivar++) useVariable[ivar]=kFALSE;
   if (fUseNvars==0) {
      // choose a default number of randomly picked variables
      fUseNvars = UInt_t(TMath::Sqrt(fNvars)+0.6);
   }
   if (fUsePoissonNvars) useNvars=TMath::Min(fNvars,TMath::Max(UInt_t(1),(UInt_t) fMyTrandom->Poisson(fUseNvars)));
   else useNvars = fUseNvars;

   UInt_t nSelectedVars = 0;
   while (nSelectedVars < useNvars) {
      Double_t bla = fMyTrandom->Rndm()*fNvars;
      useVariable[Int_t (bla)] = kTRUE;
      nSelectedVars = 0;
      for (UInt_t ivar=0; ivar < fNvars; ivar++) {
         if (useVariable[ivar] == kTRUE) {
            mapVariable[nSelectedVars] = ivar;
            nSelectedVars++;
         }
      }
   }
   if (nSelectedVars != useNvars) { std::cout << "Bug in TrainNode - GetRandomisedVariables()... sorry" << std::endl; std::exit(1);}
}

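//_______________________________________________________________________
// decide how to split a node using the variable that gives the best
// separation of signal and background: for each variable a grid of
// fNCuts cut values spanning the variable range in this node is scanned,
// and the cut maximising the separation gain is kept; optionally a
// Fisher discriminant of linearly correlated variables is added as one
// extra split candidate; returns the best separation gain found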
Double_t TMVA::DecisionTree::TrainNodeFast( const vector<TMVA::Event*> & eventSample,
                                            TMVA::DecisionTreeNode *node )
{
   Double_t separationGain = -1, sepTmp;
   Double_t cutValue=-999;
   Int_t mxVar= -1;
   Int_t cutIndex=-1;
   Bool_t cutType=kTRUE;
   Double_t nTotS, nTotB;
   Int_t nTotS_unWeighted, nTotB_unWeighted;
   UInt_t nevents = eventSample.size();

   // the "+1" slot behind the last input variable is reserved for the
   // Fisher discriminant, which is treated as an additional variable
   Bool_t *useVariable = new Bool_t[fNvars+1];
   UInt_t *mapVariable = new UInt_t[fNvars+1];

   std::vector<Double_t> fisherCoeff;

   if (fRandomisedTree) { // choose for each node a random subset of variables
      UInt_t tmp=fUseNvars;
      GetRandomisedVariables(useVariable,mapVariable,tmp);
   }
   else {
      for (UInt_t ivar=0; ivar < fNvars; ivar++) {
         useVariable[ivar] = kTRUE;
         mapVariable[ivar] = ivar;
      }
   }
   useVariable[fNvars] = kFALSE;

   if (fUseFisherCuts) {
      useVariable[fNvars] = kTRUE; // the Fisher variable is always considered

      // use the Fisher discriminant only for those variables that show
      // some linear correlation in either signal or background
      Bool_t *useVarInFisher = new Bool_t[fNvars];
      UInt_t *mapVarInFisher = new UInt_t[fNvars];
      for (UInt_t ivar=0; ivar < fNvars; ivar++) {
         useVarInFisher[ivar] = kFALSE;
         mapVarInFisher[ivar] = ivar;
      }

      std::vector<TMatrixDSym*>* covMatrices;
      covMatrices = gTools().CalcCovarianceMatrices( eventSample, 2 );
      TMatrixD *ss = new TMatrixD(*(covMatrices->at(0)));
      TMatrixD *bb = new TMatrixD(*(covMatrices->at(1)));
      const TMatrixD *s = gTools().GetCorrelationMatrix(ss);
      const TMatrixD *b = gTools().GetCorrelationMatrix(bb);

      for (UInt_t ivar=0; ivar < fNvars; ivar++) {
         for (UInt_t jvar=ivar+1; jvar < fNvars; jvar++) {
            if ( ( TMath::Abs( (*s)(ivar, jvar)) > fMinLinCorrForFisher) ||
                 ( TMath::Abs( (*b)(ivar, jvar)) > fMinLinCorrForFisher) ){
               useVarInFisher[ivar] = kTRUE;
               useVarInFisher[jvar] = kTRUE;
            }
         }
      }

      // now that the variables for the Fisher are known, count and map
      // them, such that arrays/matrices can be filled with only those
      // variables
      UInt_t nFisherVars = 0;
      for (UInt_t ivar=0; ivar < fNvars; ivar++) {
         // use only those variables for the Fisher that are also used in general
         if (useVarInFisher[ivar] && useVariable[ivar]) {
            mapVarInFisher[nFisherVars++]=ivar;
            // if specified, variables used in the Fisher cut are not scanned
            // again as individual split variables
            if (fUseExclusiveVars) useVariable[ivar] = kFALSE;
         }
      }

      fisherCoeff = this->GetFisherCoefficients(eventSample, nFisherVars, mapVarInFisher);
      delete [] useVarInFisher;
      delete [] mapVarInFisher;
   }

   const UInt_t nBins = fNCuts+1;
   UInt_t cNvars = fNvars;
   if (fUseFisherCuts) cNvars++; // the Fisher output counts as one more variable

   Double_t** nSelS = new Double_t* [cNvars];
   Double_t** nSelB = new Double_t* [cNvars];
   Double_t** nSelS_unWeighted = new Double_t* [cNvars];
   Double_t** nSelB_unWeighted = new Double_t* [cNvars];
   Double_t** target  = new Double_t* [cNvars];
   Double_t** target2 = new Double_t* [cNvars];
   Double_t** cutValues = new Double_t* [cNvars];

   for (UInt_t i=0; i<cNvars; i++) {
      nSelS[i] = new Double_t [nBins];
      nSelB[i] = new Double_t [nBins];
      nSelS_unWeighted[i] = new Double_t [nBins];
      nSelB_unWeighted[i] = new Double_t [nBins];
      target[i]  = new Double_t [nBins];
      target2[i] = new Double_t [nBins];
      cutValues[i] = new Double_t [nBins];
   }

   Double_t *xmin = new Double_t[cNvars];
   Double_t *xmax = new Double_t[cNvars];

   for (UInt_t ivar=0; ivar < cNvars; ivar++) {
      if (ivar < fNvars){
         xmin[ivar]=node->GetSampleMin(ivar);
         xmax[ivar]=node->GetSampleMax(ivar);
      } else {
         // the Fisher variable: its range is not stored in the node,
         // so determine it from the events
         xmin[ivar] =  std::numeric_limits<Double_t>::max();
         xmax[ivar] = -std::numeric_limits<Double_t>::max();

         for (UInt_t iev=0; iev<nevents; iev++) {
            Double_t result = fisherCoeff[fNvars]; // the Fisher constant offset
            for (UInt_t jvar=0; jvar<fNvars; jvar++)
               result += fisherCoeff[jvar]*(eventSample[iev])->GetValue(jvar);
            if (result > xmax[ivar]) xmax[ivar]=result;
            if (result < xmin[ivar]) xmin[ivar]=result;
         }
      }
      for (UInt_t ibin=0; ibin<nBins; ibin++) {
         nSelS[ivar][ibin]=0;
         nSelB[ivar][ibin]=0;
         nSelS_unWeighted[ivar][ibin]=0;
         nSelB_unWeighted[ivar][ibin]=0;
         target[ivar][ibin]=0;
         target2[ivar][ibin]=0;
         cutValues[ivar][ibin]=0;
      }
   }

   // set the grid of cut values, equally spaced between the minimum and
   // maximum of each variable in this node
   for (UInt_t ivar=0; ivar < cNvars; ivar++) {

      if ( useVariable[ivar] ) {

         Double_t istepSize =( xmax[ivar] - xmin[ivar] ) / Double_t(nBins);
         for (Int_t icut=0; icut<fNCuts; icut++) {
            cutValues[ivar][icut]=xmin[ivar]+(Double_t(icut+1))*istepSize;
         }
      }
   }

   nTotS=0; nTotB=0;
   nTotS_unWeighted=0; nTotB_unWeighted=0;
   for (UInt_t iev=0; iev<nevents; iev++) {

      Double_t eventWeight = eventSample[iev]->GetWeight();
      if (eventSample[iev]->GetClass() == fClass) {
         nTotS+=eventWeight;
         nTotS_unWeighted++;
      }
      else {
         nTotB+=eventWeight;
         nTotB_unWeighted++;
      }

      // fill the histogram grid: count the events (and regression targets)
      // falling into each bin of each variable
      Int_t iBin=-1;
      for (UInt_t ivar=0; ivar < cNvars; ivar++) {

         if ( useVariable[ivar] ) {
            Double_t eventData;
            if (ivar < fNvars) eventData = eventSample[iev]->GetValue(ivar);
            else { // the Fisher variable
               eventData = fisherCoeff[fNvars];
               for (UInt_t jvar=0; jvar<fNvars; jvar++)
                  eventData += fisherCoeff[jvar]*(eventSample[iev])->GetValue(jvar);
            }
            // the maximal bin index is nBins-1 (counting starts at 0)
            iBin = TMath::Min(Int_t(nBins-1),TMath::Max(0,int (nBins*(eventData-xmin[ivar])/(xmax[ivar]-xmin[ivar]) ) ));
            if (eventSample[iev]->GetClass() == fClass) {
               nSelS[ivar][iBin]+=eventWeight;
               nSelS_unWeighted[ivar][iBin]++;
            }
            else {
               nSelB[ivar][iBin]+=eventWeight;
               nSelB_unWeighted[ivar][iBin]++;
            }
            if (DoRegression()) {
               target[ivar][iBin] +=eventWeight*eventSample[iev]->GetTarget(0);
               target2[ivar][iBin]+=eventWeight*eventSample[iev]->GetTarget(0)*eventSample[iev]->GetTarget(0);
            }
         }
      }
   }

   // turn the bin contents into cumulative distributions
   for (UInt_t ivar=0; ivar < cNvars; ivar++) {
      if (useVariable[ivar]) {
         for (UInt_t ibin=1; ibin < nBins; ibin++) {
            nSelS[ivar][ibin]+=nSelS[ivar][ibin-1];
            nSelS_unWeighted[ivar][ibin]+=nSelS_unWeighted[ivar][ibin-1];
            nSelB[ivar][ibin]+=nSelB[ivar][ibin-1];
            nSelB_unWeighted[ivar][ibin]+=nSelB_unWeighted[ivar][ibin-1];
            if (DoRegression()) {
               target[ivar][ibin] +=target[ivar][ibin-1] ;
               target2[ivar][ibin]+=target2[ivar][ibin-1];
            }
         }
         // internal consistency checks on the cumulative counts
         if (nSelS_unWeighted[ivar][nBins-1] +nSelB_unWeighted[ivar][nBins-1] != eventSample.size()) {
            Log() << kFATAL << "<TrainNodeFast> cumulative unweighted counts nSelS + nSelB = "
                  << nSelS_unWeighted[ivar][nBins-1] +nSelB_unWeighted[ivar][nBins-1]
                  << " do not match the event sample size = " << eventSample.size()
                  << Endl;
         }
         double lastBins=nSelS[ivar][nBins-1] +nSelB[ivar][nBins-1];
         double totalSum=nTotS+nTotB;
         if (TMath::Abs(lastBins-totalSum)/totalSum>0.01) {
            Log() << kFATAL << "<TrainNodeFast> cumulative weighted counts nSelS + nSelB = "
                  << lastBins
                  << " do not match the total number of events = " << totalSum
                  << Endl;
         }
      }
   }

   // now select the optimal cut: scan the grid and keep the (variable, bin)
   // pair with the largest separation gain
   for (UInt_t ivar=0; ivar < cNvars; ivar++) {
      if (useVariable[ivar]) {
         for (UInt_t iBin=0; iBin<nBins-1; iBin++) {

            // the cut is considered at the RIGHT edge of each bin; demand a
            // minimal number of (unweighted) events on both sides of the cut
            // to avoid statistically insignificant splits
            Double_t sl = nSelS_unWeighted[ivar][iBin];
            Double_t bl = nSelB_unWeighted[ivar][iBin];
            Double_t s  = nTotS_unWeighted;
            Double_t b  = nTotB_unWeighted;
            Double_t sr = s-sl;
            Double_t br = b-bl;
            if ( (sl+bl)>=fMinSize && (sr+br)>=fMinSize ) {

               if (DoRegression()) {
                  sepTmp = fRegType->GetSeparationGain(nSelS[ivar][iBin]+nSelB[ivar][iBin],
                                                       target[ivar][iBin],target2[ivar][iBin],
                                                       nTotS+nTotB,
                                                       target[ivar][nBins-1],target2[ivar][nBins-1]);
               } else {
                  sepTmp = fSepType->GetSeparationGain(nSelS[ivar][iBin], nSelB[ivar][iBin], nTotS, nTotB);
               }
               if (separationGain < sepTmp) {
                  separationGain = sepTmp;
                  mxVar = ivar;
                  cutIndex = iBin;
                  if (cutIndex >= fNCuts) Log() << kFATAL << "<TrainNodeFast> cut index out of range: ibin = " << iBin << Endl;
               }
            }
         }
      }
   }

   if (DoRegression()) {
      node->SetSeparationIndex(fRegType->GetSeparationIndex(nTotS+nTotB,target[0][nBins-1],target2[0][nBins-1]));
      node->SetResponse(target[0][nBins-1]/(nTotS+nTotB));
      node->SetRMS(TMath::Sqrt(target2[0][nBins-1]/(nTotS+nTotB) - target[0][nBins-1]/(nTotS+nTotB)*target[0][nBins-1]/(nTotS+nTotB)));
   }
   else {
      node->SetSeparationIndex(fSepType->GetSeparationIndex(nTotS,nTotB));
   }
   if (mxVar >= 0) {
      if (nSelS[mxVar][cutIndex]/nTotS > nSelB[mxVar][cutIndex]/nTotB) cutType=kTRUE;
      else cutType=kFALSE;
      cutValue = cutValues[mxVar][cutIndex];

      node->SetSelector((UInt_t)mxVar);
      node->SetCutValue(cutValue);
      node->SetCutType(cutType);
      node->SetSeparationGain(separationGain);
      if (mxVar < (Int_t) fNvars){ // the best cut is on one of the ordinary variables
         node->SetNFisherCoeff(0);
         fVariableImportance[mxVar] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB);
      }else{ // the best cut is on the Fisher variable
         // store the Fisher coefficients (including the offset in the
         // fNvars+1-st element) in the node, and share the variable
         // importance among the variables according to their coefficients
         node->SetNFisherCoeff(fNvars+1);
         for (UInt_t ivar=0; ivar<=fNvars; ivar++) {
            node->SetFisherCoeff(ivar,fisherCoeff[ivar]);
            if (ivar<fNvars){
               fVariableImportance[ivar] += fisherCoeff[ivar]*separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB);
            }
         }
      }
   }
   else {
      separationGain = 0; // no cut was found that improves the separation
   }

   for (UInt_t i=0; i<cNvars; i++) {
      delete [] nSelS[i];
      delete [] nSelB[i];
      delete [] nSelS_unWeighted[i];
      delete [] nSelB_unWeighted[i];
      delete [] target[i];
      delete [] target2[i];
      delete [] cutValues[i];
   }
   delete [] nSelS;
   delete [] nSelB;
   delete [] nSelS_unWeighted;
   delete [] nSelB_unWeighted;
   delete [] target;
   delete [] target2;
   delete [] cutValues;

   delete [] xmin;
   delete [] xmax;

   delete [] useVariable;
   delete [] mapVariable;

   return separationGain;
}

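//_______________________________________________________________________
// calculate the Fisher linear discriminant coefficients for the selected
// subset of variables, following the standard Fisher analysis: class
// means, within-class and between-class scatter matrices; the returned
// vector has fNvars+1 entries, the last one being the constant offset F0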
std::vector<Double_t> TMVA::DecisionTree::GetFisherCoefficients(const EventList &eventSample, UInt_t nFisherVars, UInt_t *mapVarInFisher){
   std::vector<Double_t> fisherCoeff(fNvars+1);

   // matrix of the class means: column 0 = signal, column 1 = background,
   // column 2 = overall mean
   TMatrixD* meanMatx = new TMatrixD( nFisherVars, 3 );

   // between-class, within-class and total covariance matrices
   TMatrixD* betw = new TMatrixD( nFisherVars, nFisherVars );
   TMatrixD* with = new TMatrixD( nFisherVars, nFisherVars );
   TMatrixD* cov  = new TMatrixD( nFisherVars, nFisherVars );

   // compute the mean of each variable for signal and background, as well
   // as the sums of weights
   Double_t sumOfWeightsS = 0;
   Double_t sumOfWeightsB = 0;

   Double_t* sumS = new Double_t[nFisherVars];
   Double_t* sumB = new Double_t[nFisherVars];
   for (UInt_t ivar=0; ivar<nFisherVars; ivar++) { sumS[ivar] = sumB[ivar] = 0; }

   UInt_t nevents = eventSample.size();

   for (UInt_t ievt=0; ievt<nevents; ievt++) {

      const Event * ev = eventSample[ievt];

      Double_t weight = ev->GetWeight();
      if (ev->GetClass() == fClass) sumOfWeightsS += weight;
      else                          sumOfWeightsB += weight;

      Double_t* sum = ev->GetClass() == fClass ? sumS : sumB;
      for (UInt_t ivar=0; ivar<nFisherVars; ivar++) sum[ivar] += ev->GetValue( mapVarInFisher[ivar] )*weight;
   }

   for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
      (*meanMatx)( ivar, 2 ) = sumS[ivar];
      (*meanMatx)( ivar, 0 ) = sumS[ivar]/sumOfWeightsS;

      (*meanMatx)( ivar, 2 ) += sumB[ivar];
      (*meanMatx)( ivar, 1 ) = sumB[ivar]/sumOfWeightsB;

      // overall mean: signal + background
      (*meanMatx)( ivar, 2 ) /= (sumOfWeightsS + sumOfWeightsB);
   }
   delete [] sumS;
   delete [] sumB;

   // compute the matrix of covariance WITHIN each class
   assert( sumOfWeightsS > 0 && sumOfWeightsB > 0 );

   const Int_t nFisherVars2 = nFisherVars*nFisherVars;
   Double_t *sum2Sig = new Double_t[nFisherVars2];
   Double_t *sum2Bgd = new Double_t[nFisherVars2];
   Double_t *xval    = new Double_t[nFisherVars2];
   memset(sum2Sig,0,nFisherVars2*sizeof(Double_t));
   memset(sum2Bgd,0,nFisherVars2*sizeof(Double_t));

   for (UInt_t ievt=0; ievt<nevents; ievt++) {

      const Event* ev = eventSample[ievt];

      Double_t weight = ev->GetWeight();

      for (UInt_t x=0; x<nFisherVars; x++) xval[x] = ev->GetValue( mapVarInFisher[x] );
      Int_t k=0;
      for (UInt_t x=0; x<nFisherVars; x++) {
         for (UInt_t y=0; y<nFisherVars; y++) {
            // each class scatters around its OWN mean (column 0 for signal,
            // column 1 for background)
            if ( ev->GetClass() == fClass )
               sum2Sig[k] += ( (xval[x] - (*meanMatx)(x, 0))*(xval[y] - (*meanMatx)(y, 0)) )*weight;
            else
               sum2Bgd[k] += ( (xval[x] - (*meanMatx)(x, 1))*(xval[y] - (*meanMatx)(y, 1)) )*weight;
            k++;
         }
      }
   }
   Int_t k=0;
   for (UInt_t x=0; x<nFisherVars; x++) {
      for (UInt_t y=0; y<nFisherVars; y++) {
         (*with)(x, y) = (sum2Sig[k] + sum2Bgd[k])/(sumOfWeightsS + sumOfWeightsB);
         k++;
      }
   }

   delete [] sum2Sig;
   delete [] sum2Bgd;
   delete [] xval;

   // compute the matrix of covariance BETWEEN the classes, built from the
   // differences of the class means from the overall mean
   Double_t prodSig, prodBgd;

   for (UInt_t x=0; x<nFisherVars; x++) {
      for (UInt_t y=0; y<nFisherVars; y++) {

         prodSig = ( ((*meanMatx)(x, 0) - (*meanMatx)(x, 2))*
                     ((*meanMatx)(y, 0) - (*meanMatx)(y, 2)) );
         prodBgd = ( ((*meanMatx)(x, 1) - (*meanMatx)(x, 2))*
                     ((*meanMatx)(y, 1) - (*meanMatx)(y, 2)) );

         (*betw)(x, y) = (sumOfWeightsS*prodSig + sumOfWeightsB*prodBgd) / (sumOfWeightsS + sumOfWeightsB);
      }
   }

   // total covariance matrix = within + between
   for (UInt_t x=0; x<nFisherVars; x++)
      for (UInt_t y=0; y<nFisherVars; y++)
         (*cov)(x, y) = (*with)(x, y) + (*betw)(x, y);

   // invert the within-class matrix to obtain the Fisher coefficients
   TMatrixD* theMat = with;

   TMatrixD invCov( *theMat );
   if ( TMath::Abs(invCov.Determinant()) < 10E-24 ) {
      Log() << kWARNING << "FisherCoeff matrix is almost singular with determinant="
            << TMath::Abs(invCov.Determinant())
            << " did you use variables that are linear combinations or highly correlated?"
            << Endl;
   }
   if ( TMath::Abs(invCov.Determinant()) < 10E-120 ) {
      Log() << kFATAL << "FisherCoeff matrix is singular with determinant="
            << TMath::Abs(invCov.Determinant())
            << " did you use variables that are linear combinations?"
            << Endl;
   }

   invCov.Invert();

   // get the Fisher coefficients
   Double_t xfact = TMath::Sqrt( sumOfWeightsS*sumOfWeightsB ) / (sumOfWeightsS + sumOfWeightsB);

   std::vector<Double_t> diffMeans( nFisherVars );

   for (UInt_t ivar=0; ivar<=fNvars; ivar++) fisherCoeff[ivar] = 0;
   for (UInt_t ivar=0; ivar<nFisherVars; ivar++) {
      for (UInt_t jvar=0; jvar<nFisherVars; jvar++) {
         Double_t d = (*meanMatx)(jvar, 0) - (*meanMatx)(jvar, 1);
         fisherCoeff[mapVarInFisher[ivar]] += invCov(ivar, jvar)*d;
      }

      // rescale the Fisher coefficients
      fisherCoeff[mapVarInFisher[ivar]] *= xfact;
   }

   // offset F0
   Double_t f0 = 0.0;
   for (UInt_t ivar=0; ivar<nFisherVars; ivar++){
      f0 += fisherCoeff[mapVarInFisher[ivar]]*((*meanMatx)(ivar, 0) + (*meanMatx)(ivar, 1));
   }
   f0 /= -2.0;

   fisherCoeff[fNvars] = f0;

   delete meanMatx;
   delete betw;
   delete with;
   delete cov;

   return fisherCoeff;
}

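//_______________________________________________________________________
// train a node by finding the single optimal cut for a single variable
// that best separates signal and background (maximises the separation
// gain): the events are sorted in each variable and every allowed cut
// position between two neighbouring values is tested; used when
// fNCuts < 0 (classification only)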
Double_t TMVA::DecisionTree::TrainNodeFull( const vector<TMVA::Event*> & eventSample,
                                            TMVA::DecisionTreeNode *node )
{
   Double_t nTotS = 0.0, nTotB = 0.0;
   Int_t nTotS_unWeighted = 0, nTotB_unWeighted = 0;

   vector<TMVA::BDTEventWrapper> bdtEventSample;

   // lists of the optimal cut, separation gain, and cut type (removes
   // background or signal), one entry per variable
   vector<Double_t> lCutValue( fNvars, 0.0 );
   vector<Double_t> lSepGain( fNvars, -1.0e6 );
   vector<Char_t>   lCutType( fNvars );
   lCutType.assign( fNvars, Char_t(kFALSE) );

   // set the pointers to the data points and count the signal and
   // background weights
   for( vector<TMVA::Event*>::const_iterator it = eventSample.begin(); it != eventSample.end(); ++it ) {
      if((*it)->GetClass() == fClass) {
         nTotS += (*it)->GetWeight();
         ++nTotS_unWeighted;
      }
      else {
         nTotB += (*it)->GetWeight();
         ++nTotB_unWeighted;
      }
      bdtEventSample.push_back(TMVA::BDTEventWrapper(*it));
   }

   vector<Char_t> useVariable(fNvars); // Char_t to avoid the std::vector<bool> specialisation
   useVariable.assign( fNvars, Char_t(kTRUE) );

   for (UInt_t ivar=0; ivar < fNvars; ivar++) useVariable[ivar]=Char_t(kFALSE);
   if (fRandomisedTree) { // randomly pick a subset of variables for this node
      if (fUseNvars ==0 ) {
         // choose a default number of randomly picked variables
         fUseNvars = UInt_t(TMath::Sqrt(fNvars)+0.6);
      }
      Int_t nSelectedVars = 0;
      while (nSelectedVars < fUseNvars) {
         Double_t bla = fMyTrandom->Rndm()*fNvars;
         useVariable[Int_t (bla)] = Char_t(kTRUE);
         nSelectedVars = 0;
         for (UInt_t ivar=0; ivar < fNvars; ivar++) {
            if(useVariable[ivar] == Char_t(kTRUE)) nSelectedVars++;
         }
      }
   }
   else {
      for (UInt_t ivar=0; ivar < fNvars; ivar++) useVariable[ivar] = Char_t(kTRUE);
   }

   for( UInt_t ivar = 0; ivar < fNvars; ivar++ ) { // loop over all discriminating variables
      if(!useVariable[ivar]) continue;
      TMVA::BDTEventWrapper::SetVarIndex(ivar); // select the variable to sort by
      std::sort( bdtEventSample.begin(),bdtEventSample.end() );

      Double_t bkgWeightCtr = 0.0, sigWeightCtr = 0.0;
      vector<TMVA::BDTEventWrapper>::iterator it = bdtEventSample.begin(), it_end = bdtEventSample.end();
      for( ; it != it_end; ++it ) {
         if((**it)->GetClass() == fClass )
            sigWeightCtr += (**it)->GetWeight();
         else
            bkgWeightCtr += (**it)->GetWeight();
         // store the cumulative signal and background weights
         it->SetCumulativeWeight(false,bkgWeightCtr);
         it->SetCumulativeWeight(true,sigWeightCtr);
      }

      const Double_t fPMin = 1.0e-6;
      Bool_t cutType = kFALSE;
      Long64_t index = 0;
      Double_t separationGain = -1.0, sepTmp = 0.0, cutValue = 0.0, dVal = 0.0, norm = 0.0;

      // find the optimal splitting point: accept a cut only if both
      // daughters contain at least fMinSize events and the neighbouring
      // values are not degenerate
      for( it = bdtEventSample.begin(); it != it_end; ++it ) {
         if( index == 0 ) { ++index; continue; }
         if( *(*it) == NULL ) {
            Log() << kFATAL << "In TrainNodeFull(): have a null event! Where index="
                  << index << ", and parent node=" << node->GetParent() << Endl;
            break;
         }
         dVal = bdtEventSample[index].GetVal() - bdtEventSample[index-1].GetVal();
         norm = TMath::Abs(bdtEventSample[index].GetVal() + bdtEventSample[index-1].GetVal());

         if( index >= fMinSize && (nTotS_unWeighted + nTotB_unWeighted) - index >= fMinSize && TMath::Abs(dVal/(0.5*norm + 1)) > fPMin ) {
            sepTmp = fSepType->GetSeparationGain( it->GetCumulativeWeight(true), it->GetCumulativeWeight(false), sigWeightCtr, bkgWeightCtr );
            if( sepTmp > separationGain ) {
               separationGain = sepTmp;
               cutValue = it->GetVal() - 0.5*dVal;
               Double_t nSelS = it->GetCumulativeWeight(true);
               Double_t nSelB = it->GetCumulativeWeight(false);
               // the cut direction keeps the purer sample on the "pass" side
               if( nSelS/sigWeightCtr > nSelB/bkgWeightCtr ) cutType = kTRUE;
               else cutType = kFALSE;
            }
         }
         ++index;
      }
      lCutType[ivar] = Char_t(cutType);
      lCutValue[ivar] = cutValue;
      lSepGain[ivar] = separationGain;
   }

   // pick the variable with the largest separation gain
   Double_t separationGain = -1.0;
   Int_t iVarIndex = -1;
   for( UInt_t ivar = 0; ivar < fNvars; ivar++ ) {
      if( lSepGain[ivar] > separationGain ) {
         iVarIndex = ivar;
         separationGain = lSepGain[ivar];
      }
   }

   if(iVarIndex >= 0) {
      node->SetSelector(iVarIndex);
      node->SetCutValue(lCutValue[iVarIndex]);
      node->SetSeparationGain(lSepGain[iVarIndex]);
      node->SetCutType(lCutType[iVarIndex]);

      fVariableImportance[iVarIndex] += separationGain*separationGain * (nTotS+nTotB) * (nTotS+nTotB);
   }
   else {
      separationGain = 0.0;
   }

   return separationGain;
}

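//_______________________________________________________________________
// return the pointer to the leaf node in which the given event ends up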
TMVA::DecisionTreeNode* TMVA::DecisionTree::GetEventNode(const TMVA::Event & e) const
{
   TMVA::DecisionTreeNode *current = (TMVA::DecisionTreeNode*)this->GetRoot();
   while(current->GetNodeType() == 0) {
      current = (current->GoesRight(e)) ?
         (TMVA::DecisionTreeNode*)current->GetRight() :
         (TMVA::DecisionTreeNode*)current->GetLeft();
   }
   return current;
}

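//_______________________________________________________________________
// send the event through the tree starting at the root node and return
// the classification of the leaf it ends up in: the node type (signal
// or background) if UseYesNoLeaf is set, otherwise the leaf purity; for
// regression trees the leaf response is returned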
Double_t TMVA::DecisionTree::CheckEvent( const TMVA::Event & e, Bool_t UseYesNoLeaf ) const
{
   TMVA::DecisionTreeNode *current = this->GetRoot();
   if (!current)
      Log() << kFATAL << "CheckEvent: started with undefined ROOT node" << Endl;

   while (current->GetNodeType() == 0) {
      current = (current->GoesRight(e)) ?
         current->GetRight() :
         current->GetLeft();
      if (!current) {
         Log() << kFATAL << "DT::CheckEvent: inconsistent tree structure" << Endl;
      }
   }

   if ( DoRegression() ){
      return current->GetResponse();
   }
   else {
      if (UseYesNoLeaf) return Double_t ( current->GetNodeType() );
      else              return current->GetPurity();
   }
}

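//_______________________________________________________________________
// calculate the normalised purity S/(S+B) of a given event sample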
Double_t TMVA::DecisionTree::SamplePurity( vector<TMVA::Event*> eventSample )
{
   Double_t sumsig=0, sumbkg=0, sumtot=0;
   for (UInt_t ievt=0; ievt<eventSample.size(); ievt++) {
      if (eventSample[ievt]->GetClass() != fClass) sumbkg+=eventSample[ievt]->GetWeight();
      else sumsig+=eventSample[ievt]->GetWeight();
      sumtot+=eventSample[ievt]->GetWeight();
   }

   // sanity check
   if (sumtot != (sumsig+sumbkg)){
      Log() << kFATAL << "<SamplePurity> sumtot != sumsig+sumbkg: "
            << sumtot << " " << sumsig << " " << sumbkg << Endl;
   }
   if (sumtot>0) return sumsig/(sumsig + sumbkg);
   else return -1;
}

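//_______________________________________________________________________
// return the relative variable importance, normalised so that the
// importances of all variables sum to one; the importance is accumulated
// during the training as (separation gain)^2 x (weighted number of
// events)^2 of each split using the variable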
vector< Double_t > TMVA::DecisionTree::GetVariableImportance()
{
   vector<Double_t> relativeImportance(fNvars);
   Double_t sum=0;
   for (UInt_t i=0; i< fNvars; i++) {
      sum += fVariableImportance[i];
      relativeImportance[i] = fVariableImportance[i];
   }

   for (UInt_t i=0; i< fNvars; i++) {
      if (sum > std::numeric_limits<double>::epsilon())
         relativeImportance[i] /= sum;
      else
         relativeImportance[i] = 0;
   }
   return relativeImportance;
}

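//_______________________________________________________________________
// return the relative importance of variable ivar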
Double_t TMVA::DecisionTree::GetVariableImportance( UInt_t ivar )
{
   vector<Double_t> relativeImportance = this->GetVariableImportance();
   if (ivar < fNvars) return relativeImportance[ivar];
   else {
      Log() << kFATAL << "<GetVariableImportance>" << Endl
            << "--- ivar = " << ivar << " is out of range " << Endl;
   }

   return -1;
}