BinarySearchTree.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: BinarySearchTree.cxx 37986 2011-02-04 21:42:15Z pcanal $    
00002 // Author: Andreas Hoecker, Joerg Stelzer, Helge Voss, Kai Voss 
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : BinarySearchTree                                                      *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Implementation (see header file for description)                          *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00015  *      Joerg Stelzer   <stelzer@cern.ch>        - DESY, Germany                  *
00016  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00017  *      Kai Voss        <Kai.Voss@cern.ch>       - U. of Victoria, Canada         *
00018  *                                                                                *
00019  * Copyright (c) 2005:                                                            *
00020  *      CERN, Switzerland                                                         * 
00021  *      U. of Victoria, Canada                                                    * 
00022  *      MPI-K Heidelberg, Germany                                                 * 
00023  *      LAPP, Annecy, France                                                      *
00024  *                                                                                *
00025  * Redistribution and use in source and binary forms, with or without             *
00026  * modification, are permitted according to the terms listed in LICENSE           *
00027  * (http://tmva.sourceforge.net/LICENSE)                                          *
00028  *                                                                                *
00029  **********************************************************************************/
00030 
00031 //////////////////////////////////////////////////////////////////////////
00032 //                                                                      //
00033 // BinarySearchTree                                                     //
00034 //                                                                      //
00035 // A simple Binary search tree including a volume search method         //
00036 //                                                                      //
00037 //////////////////////////////////////////////////////////////////////////
00038 
00039 #include <stdexcept>
00040 #include <cstdlib>
00041 #include <queue>
00042 #include <algorithm>
00043 
00044 #if ROOT_VERSION_CODE >= 364802
00045 #ifndef ROOT_TMathBase
00046 #include "TMathBase.h"
00047 #endif
00048 #else
00049 #ifndef ROOT_TMath
00050 #include "TMath.h"
00051 #endif
00052 #endif
00053 #include "TMatrixDBase.h"
00054 #include "TObjString.h"
00055 #include "TTree.h"
00056 
00057 #ifndef ROOT_TMVA_MsgLogger
00058 #include "TMVA/MsgLogger.h"
00059 #endif
00060 #ifndef ROOT_TMVA_MethodBase
00061 #include "TMVA/MethodBase.h"
00062 #endif
00063 #ifndef ROOT_TMVA_Tools
00064 #include "TMVA/Tools.h"
00065 #endif
00066 #ifndef ROOT_TMVA_DataSet
00067 #include "TMVA/DataSet.h"
00068 #endif
00069 #ifndef ROOT_TMVA_Event
00070 #include "TMVA/Event.h"
00071 #endif
00072 #ifndef ROOT_TMVA_BinarySearchTree
00073 #include "TMVA/BinarySearchTree.h"
00074 #endif
00075 
00076 ClassImp(TMVA::BinarySearchTree)
00077 
00078 //_______________________________________________________________________
00079 TMVA::BinarySearchTree::BinarySearchTree( void ) :
00080    BinaryTree(),
00081    fPeriod      ( 1 ),
00082    fCurrentDepth( 0 ),
00083    fStatisticsIsValid( kFALSE ),
00084    fSumOfWeights( 0 ),
00085    fCanNormalize( kFALSE )
00086 {
00087    // default constructor
00088    fNEventsW[0]=fNEventsW[1]=0.;
00089 }
00090 
00091 //_______________________________________________________________________
00092 TMVA::BinarySearchTree::BinarySearchTree( const BinarySearchTree &b)
00093    : BinaryTree(), 
00094      fPeriod      ( b.fPeriod ),
00095      fCurrentDepth( 0 ),
00096      fStatisticsIsValid( kFALSE ),
00097      fSumOfWeights( b.fSumOfWeights ),
00098      fCanNormalize( kFALSE )
00099 {
00100    // copy constructor that creates a true copy, i.e. a completely independent tree 
00101    fNEventsW[0]=fNEventsW[1]=0.;
00102    Log() << kFATAL << " Copy constructor not implemented yet " << Endl;
00103 }
00104 
00105 //_______________________________________________________________________
00106 TMVA::BinarySearchTree::~BinarySearchTree( void ) 
00107 {
00108    // destructor
00109 
00110    for(std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator pIt = fNormalizeTreeTable.begin();
00111        pIt != fNormalizeTreeTable.end(); pIt++) {
00112       delete pIt->second;
00113    }
00114 }
00115 
00116 //_______________________________________________________________________
00117 TMVA::BinarySearchTree* TMVA::BinarySearchTree::CreateFromXML(void* node, UInt_t tmva_Version_Code ) {
00118    // re-create a new tree (decision tree or search tree) from XML
00119    std::string type("");
00120    gTools().ReadAttr(node,"type", type);
00121    BinarySearchTree* bt = new BinarySearchTree();
00122    bt->ReadXML( node, tmva_Version_Code );
00123    return bt;
00124 }
00125 
00126 //_______________________________________________________________________
00127 void TMVA::BinarySearchTree::Insert( const Event* event ) 
00128 {
00129    // insert a new "event" in the binary tree
00130    fCurrentDepth=0;
00131    fStatisticsIsValid = kFALSE;
00132 
00133    if (this->GetRoot() == NULL) {           // If the list is empty...
00134       this->SetRoot( new BinarySearchTreeNode(event)); //Make the new node the root.
00135       // have to use "s" for start as "r" for "root" would be the same as "r" for "right"
00136       this->GetRoot()->SetPos('s'); 
00137       this->GetRoot()->SetDepth(0);
00138       fNNodes = 1;
00139       fSumOfWeights = event->GetWeight();
00140       ((BinarySearchTreeNode*)this->GetRoot())->SetSelector((UInt_t)0);
00141       this->SetPeriode(event->GetNVariables());
00142    }
00143    else {
00144       // sanity check:
00145       if (event->GetNVariables() != (UInt_t)this->GetPeriode()) {
00146          Log() << kFATAL << "<Insert> event vector length != Periode specified in Binary Tree" << Endl
00147                << "--- event size: " << event->GetNVariables() << " Periode: " << this->GetPeriode() << Endl
00148                << "--- and all this when trying filling the "<<fNNodes+1<<"th Node" << Endl;
00149       }
00150       // insert a new node at the propper position  
00151       this->Insert(event, this->GetRoot()); 
00152    }
00153 
00154    // normalise the tree to speed up searches
00155    if (fCanNormalize) fNormalizeTreeTable.push_back( std::make_pair(0.0,new const Event(*event)) );
00156 }
00157 
00158 //_______________________________________________________________________
00159 void TMVA::BinarySearchTree::Insert( const Event *event, 
00160                                      Node *node ) 
00161 {
00162    // private internal function to insert a event (node) at the proper position
00163    fCurrentDepth++;
00164    fStatisticsIsValid = kFALSE;
00165 
00166    if (node->GoesLeft(*event)){    // If the adding item is less than the current node's data...
00167       if (node->GetLeft() != NULL){            // If there is a left node...
00168          // Add the new event to the left node
00169          this->Insert(event, node->GetLeft());
00170       } 
00171       else {                             // If there is not a left node...
00172          // Make the new node for the new event
00173          BinarySearchTreeNode* current = new BinarySearchTreeNode(event); 
00174          fNNodes++;
00175          fSumOfWeights += event->GetWeight();
00176          current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
00177          current->SetParent(node);          // Set the new node's previous node.
00178          current->SetPos('l');
00179          current->SetDepth( node->GetDepth() + 1 );
00180          node->SetLeft(current);            // Make it the left node of the current one.
00181       }  
00182    } 
00183    else if (node->GoesRight(*event)) { // If the adding item is less than or equal to the current node's data...
00184       if (node->GetRight() != NULL) {              // If there is a right node...
00185          // Add the new node to it.
00186          this->Insert(event, node->GetRight()); 
00187       } 
00188       else {                                 // If there is not a right node...
00189          // Make the new node.
00190          BinarySearchTreeNode* current = new BinarySearchTreeNode(event);   
00191          fNNodes++;
00192          fSumOfWeights += event->GetWeight();
00193          current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
00194          current->SetParent(node);              // Set the new node's previous node.
00195          current->SetPos('r');
00196          current->SetDepth( node->GetDepth() + 1 );
00197          node->SetRight(current);               // Make it the left node of the current one.
00198       }
00199    } 
00200    else Log() << kFATAL << "<Insert> neither left nor right :)" << Endl;
00201 }
00202 
00203 //_______________________________________________________________________
00204 TMVA::BinarySearchTreeNode* TMVA::BinarySearchTree::Search( Event* event ) const 
00205 { 
00206    //search the tree to find the node matching "event"
00207    return this->Search( event, this->GetRoot() );
00208 }
00209 
00210 //_______________________________________________________________________
00211 TMVA::BinarySearchTreeNode* TMVA::BinarySearchTree::Search(Event* event, Node* node) const 
00212 { 
00213    // Private, recursive, function for searching.
00214    if (node != NULL) {               // If the node is not NULL...
00215       // If we have found the node...
00216       if (((BinarySearchTreeNode*)(node))->EqualsMe(*event)) 
00217          return (BinarySearchTreeNode*)node;                  // Return it
00218       if (node->GoesLeft(*event))      // If the node's data is greater than the search item...
00219          return this->Search(event, node->GetLeft());  //Search the left node.
00220       else                          //If the node's data is less than the search item...
00221          return this->Search(event, node->GetRight()); //Search the right node.
00222    }
00223    else return NULL; //If the node is NULL, return NULL.
00224 }
00225 
00226 //_______________________________________________________________________
00227 Double_t TMVA::BinarySearchTree::GetSumOfWeights( void ) const
00228 {
00229    //return the sum of event (node) weights
00230    if (fSumOfWeights <= 0) {
00231       Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
00232             << " I call CalcStatistics which hopefully fixes things" 
00233             << Endl;
00234    }
00235    if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
00236 
00237    return fSumOfWeights;
00238 }
00239 
00240 //_______________________________________________________________________
00241 Double_t TMVA::BinarySearchTree::GetSumOfWeights( Int_t theType ) const
00242 {
00243    //return the sum of event (node) weights
00244    if (fSumOfWeights <= 0) {
00245       Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
00246               << " I call CalcStatistics which hopefully fixes things" 
00247               << Endl;
00248    }
00249    if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
00250 
00251    return fNEventsW[ ( theType == Types::kSignal) ? 0 : 1  ];
00252 }
00253 
00254 //_______________________________________________________________________
00255 Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, const std::vector<Int_t>& theVars, 
00256                                        Int_t theType )
00257 {
00258    // create the search tree from the event collection 
00259    // using ONLY the variables specified in "theVars"
00260    fPeriod = theVars.size();
00261    return Fill(events, theType);
00262 }
00263 
00264 //_______________________________________________________________________
00265 Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, Int_t theType )
00266 {
00267    // create the search tree from the events in a TTree
00268    // using ALL the variables specified included in the Event
00269    UInt_t n=events.size();
00270   
00271    UInt_t nevents = 0;
00272    if (fSumOfWeights != 0) {
00273       Log() << kWARNING 
00274               << "You are filling a search three that is not empty.. "
00275               << " do you know what you are doing?"
00276               << Endl;
00277    }
00278    for (UInt_t ievt=0; ievt<n; ievt++) {
00279       // insert event into binary tree
00280       if (theType == -1 || (Int_t(events[ievt]->GetClass()) == theType) ) {
00281          this->Insert( events[ievt] );
00282          nevents++;
00283          fSumOfWeights += events[ievt]->GetWeight();
00284       }
00285    } // end of event loop
00286    CalcStatistics();
00287 
00288    return fSumOfWeights;
00289 }
00290 
00291 //_______________________________________________________________________
00292 void TMVA::BinarySearchTree::NormalizeTree ( std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftBound, 
00293                                              std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightBound, 
00294                                              UInt_t actDim )
00295 {
00296 
00297    // normalises the binary-search tree to reduce the branch length and hence speed up the 
00298    // search procedure (on average)
00299    if (leftBound == rightBound) return;
00300    
00301    if (actDim == fPeriod)  actDim = 0;
00302    for (std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator i=leftBound; i!=rightBound; i++) {
00303       i->first = i->second->GetValue( actDim );
00304    }
00305    
00306    std::sort( leftBound, rightBound );
00307    
00308    std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftTemp  = leftBound;
00309    std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightTemp = rightBound;
00310   
00311    // meet in the middle
00312    while (true) {
00313       rightTemp--; 
00314       if (rightTemp == leftTemp ) {
00315          break;
00316       }
00317       leftTemp++;  
00318       if (leftTemp  == rightTemp) {
00319          break;
00320       }
00321    }
00322   
00323    std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator mid     = leftTemp;
00324    std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator midTemp = mid;
00325 
00326    if (mid!=leftBound) midTemp--;
00327 
00328    while (mid != leftBound && mid->second->GetValue( actDim ) == midTemp->second->GetValue( actDim ))  {
00329       mid--; 
00330       midTemp--;
00331    }
00332 
00333    Insert( mid->second );
00334 
00335    //    Print(cout);
00336    //    cout << endl << endl;
00337 
00338    NormalizeTree( leftBound, mid, actDim+1 );
00339    mid++;
00340    //    Print(cout);
00341    //    cout << endl << endl;
00342    NormalizeTree( mid, rightBound, actDim+1 );
00343 
00344 
00345    return;  
00346 }
00347 
00348 //_______________________________________________________________________
00349 void TMVA::BinarySearchTree::NormalizeTree()
00350 {
00351    // Normalisation of tree
00352    SetNormalize( kFALSE );
00353    Clear( NULL );
00354    this->SetRoot(NULL);
00355    NormalizeTree( fNormalizeTreeTable.begin(), fNormalizeTreeTable.end(), 0 ); 
00356 }
00357 
00358 //_______________________________________________________________________
00359 void TMVA::BinarySearchTree::Clear( Node* n )
00360 {
00361    // clear nodes
00362    BinarySearchTreeNode* currentNode = (BinarySearchTreeNode*)(n == NULL ? this->GetRoot() : n);
00363 
00364    if (currentNode->GetLeft()  != 0) Clear( currentNode->GetLeft()  );
00365    if (currentNode->GetRight() != 0) Clear( currentNode->GetRight() );
00366    
00367    if (n != NULL) delete n;
00368 
00369    return;
00370 }
00371 
00372 //_______________________________________________________________________
00373 Double_t TMVA::BinarySearchTree::SearchVolume( Volume* volume, 
00374                                                std::vector<const BinarySearchTreeNode*>* events )
00375 {
00376    // search the whole tree and add up all weigths of events that 
00377    // lie within the given voluem
00378    return SearchVolume( this->GetRoot(), volume, 0, events );
00379 }
00380 
00381 //_______________________________________________________________________
00382 Double_t TMVA::BinarySearchTree::SearchVolume( Node* t, Volume* volume, Int_t depth, 
00383                                                std::vector<const BinarySearchTreeNode*>* events )
00384 {
00385    // recursively walk through the daughter nodes and add up all weigths of events that 
00386    // lie within the given volume
00387 
00388    if (t==NULL) return 0;  // Are we at an outer leave?
00389    
00390    BinarySearchTreeNode* st = (BinarySearchTreeNode*)t;
00391 
00392    Double_t count = 0.0;
00393    if (InVolume( st->GetEventV(), volume )) {
00394       count += st->GetWeight();
00395       if (NULL != events) events->push_back( st );
00396    }
00397    if (st->GetLeft()==NULL && st->GetRight()==NULL) {
00398       
00399       return count;  // Are we at an outer leave?
00400    }
00401 
00402    Bool_t tl, tr;
00403    Int_t  d = depth%this->GetPeriode();
00404    if (d != st->GetSelector()) {
00405       Log() << kFATAL << "<SearchVolume> selector in Searchvolume " 
00406               << d << " != " << "node "<< st->GetSelector() << Endl;
00407    }
00408    tl = (*(volume->fLower))[d] <  st->GetEventV()[d];  // Should we descend left?
00409    tr = (*(volume->fUpper))[d] >= st->GetEventV()[d];  // Should we descend right?
00410 
00411    if (tl) count += SearchVolume( st->GetLeft(),  volume, (depth+1), events );
00412    if (tr) count += SearchVolume( st->GetRight(), volume, (depth+1), events );
00413 
00414    return count;
00415 }
00416 
00417 Bool_t TMVA::BinarySearchTree::InVolume(const std::vector<Float_t>& event, Volume* volume ) const 
00418 {
00419    // test if the data points are in the given volume
00420 
00421    Bool_t result = false;
00422    for (UInt_t ivar=0; ivar< fPeriod; ivar++) {
00423       result = ( (*(volume->fLower))[ivar] <  event[ivar] &&
00424                  (*(volume->fUpper))[ivar] >= event[ivar] );
00425       if (!result) break;
00426    }
00427    return result;
00428 }
00429 
00430 //_______________________________________________________________________
00431 void TMVA::BinarySearchTree::CalcStatistics( Node* n )
00432 {
00433    // calculate basic statistics (mean, rms for each variable)
00434    if (fStatisticsIsValid) return;
00435 
00436    BinarySearchTreeNode * currentNode = (BinarySearchTreeNode*)n;
00437 
00438    // default, start at the tree top, then descend recursively
00439    if (n == NULL) {
00440       fSumOfWeights = 0;
00441       for (Int_t sb=0; sb<2; sb++) {
00442          fNEventsW[sb]  = 0;
00443          fMeans[sb]     = std::vector<Float_t>(fPeriod);
00444          fRMS[sb]       = std::vector<Float_t>(fPeriod);
00445          fMin[sb]       = std::vector<Float_t>(fPeriod);
00446          fMax[sb]       = std::vector<Float_t>(fPeriod);
00447          fSum[sb]       = std::vector<Double_t>(fPeriod);
00448          fSumSq[sb]     = std::vector<Double_t>(fPeriod);
00449          for (UInt_t j=0; j<fPeriod; j++) {
00450             fMeans[sb][j] = fRMS[sb][j] = fSum[sb][j] = fSumSq[sb][j] = 0;
00451             fMin[sb][j] =  FLT_MAX;
00452             fMax[sb][j] = -FLT_MAX; 
00453          }
00454       }
00455       currentNode = (BinarySearchTreeNode*) this->GetRoot();
00456       if (currentNode == NULL) return; // no root-node
00457    }
00458       
00459    const std::vector<Float_t> & evtVec = currentNode->GetEventV();
00460    Double_t                     weight = currentNode->GetWeight();
00461 //    Int_t                        type   = currentNode->IsSignal(); 
00462    Int_t                        type   = currentNode->IsSignal() ? 0 : 1; 
00463    fNEventsW[type] += weight;
00464    fSumOfWeights   += weight;
00465 
00466    for (UInt_t j=0; j<fPeriod; j++) {
00467       Float_t val = evtVec[j];
00468       fSum[type][j]   += val*weight;
00469       fSumSq[type][j] += val*val*weight;
00470       if (val < fMin[type][j]) fMin[type][j] = val; 
00471       if (val > fMax[type][j]) fMax[type][j] = val; 
00472    }
00473 
00474    if ( (currentNode->GetLeft()  != NULL) ) CalcStatistics( currentNode->GetLeft()  ); 
00475    if ( (currentNode->GetRight() != NULL) ) CalcStatistics( currentNode->GetRight() ); 
00476 
00477    if (n == NULL) { // i.e. the root node
00478       for (Int_t sb=0; sb<2; sb++) {
00479          for (UInt_t j=0; j<fPeriod; j++) {
00480             if (fNEventsW[sb] == 0) { fMeans[sb][j] = fRMS[sb][j] = 0; continue; }
00481             fMeans[sb][j] = fSum[sb][j]/fNEventsW[sb];
00482             fRMS[sb][j]   = TMath::Sqrt(fSumSq[sb][j]/fNEventsW[sb] - fMeans[sb][j]*fMeans[sb][j]);
00483          }
00484       }
00485       fStatisticsIsValid = kTRUE;
00486    }
00487    
00488    return;
00489 }
00490 
00491 Int_t TMVA::BinarySearchTree::SearchVolumeWithMaxLimit( Volume *volume, std::vector<const BinarySearchTreeNode*>* events,
00492                                                         Int_t max_points )
00493 {
00494    // recursively walk through the daughter nodes and add up all weigths of events that 
00495    // lie within the given volume a maximum number of events can be given
00496    if (this->GetRoot() == NULL) return 0;  // Are we at an outer leave?
00497 
00498    std::queue< std::pair< const BinarySearchTreeNode*, Int_t > > queue;
00499    std::pair< const BinarySearchTreeNode*, Int_t > st = std::make_pair( (const BinarySearchTreeNode*)this->GetRoot(), 0 );
00500    queue.push( st );
00501 
00502    Int_t count = 0;
00503    
00504    while ( !queue.empty() ) {
00505       st = queue.front(); queue.pop();
00506       
00507       if (count == max_points)
00508          return count;
00509 
00510       if (InVolume( st.first->GetEventV(), volume )) {
00511          count++;
00512          if (NULL != events) events->push_back( st.first );
00513       }
00514 
00515       Bool_t tl, tr;
00516       Int_t d = st.second;
00517       if ( d == Int_t(this->GetPeriode()) ) d = 0;
00518 
00519       if (d != st.first->GetSelector()) {
00520          Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
00521                  << d << " != " << "node "<< st.first->GetSelector() << Endl;
00522       }
00523 
00524       tl = (*(volume->fLower))[d] <  st.first->GetEventV()[d] && st.first->GetLeft()  != NULL;  // Should we descend left?
00525       tr = (*(volume->fUpper))[d] >= st.first->GetEventV()[d] && st.first->GetRight() != NULL;  // Should we descend right?
00526 
00527       if (tl) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetLeft(), d+1 ) );
00528       if (tr) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetRight(), d+1 ) );
00529    }
00530 
00531    return count;
00532 }

Generated on Tue Jul 5 15:16:40 2011 for ROOT_528-00b_version by  doxygen 1.5.1