NodekNN.h

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: NodekNN.h 31458 2009-11-30 13:58:20Z stelzer $
00002 // Author: Rustem Ospanov 
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : Node                                                                  *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      kd-tree (binary tree) template                                            *
00012  *                                                                                *
00013  * Author:                                                                        *
00014  *      Rustem Ospanov <rustem@fnal.gov> - U. of Texas at Austin, USA             *
00015  *                                                                                *
00016  * Copyright (c) 2007:                                                            *
00017  *      CERN, Switzerland                                                         * 
00018  *      MPI-K Heidelberg, Germany                                                 * 
00019  *      U. of Texas at Austin, USA                                                *
00020  *                                                                                *
00021  * Redistribution and use in source and binary forms, with or without             *
00022  * modification, are permitted according to the terms listed in LICENSE           *
00023  * (http://tmva.sourceforge.net/LICENSE)                                          *
00024  **********************************************************************************/
00025 
00026 #ifndef ROOT_TMVA_NodekNN
00027 #define ROOT_TMVA_NodekNN
00028 
00029 // C++
00030 #include <list>
00031 #include <string>
00032 #include <iostream>
00033 
00034 // ROOT
00035 #ifndef ROOT_Rtypes
00036 #include "Rtypes.h"
00037 #endif
00038 
00039 //////////////////////////////////////////////////////////////////////////
00040 //                                                                      //
00041 // kNN::Node                                                            //
00042 //                                                                      //
00043 // This file contains binary tree and global function template          //
00044 // that searches tree for k-nearest neighbors                           //
00045 //                                                                      //
00046 // Node class template parameter T has to provide these functions:      //
00047 //   rtype GetVar(UInt_t) const;                                        //
00048 //   - rtype is any type convertible to Float_t                         //
00049 //   UInt_t GetNVar(void) const;                                        //
00050 //   rtype GetWeight(void) const;                                       //
00051 //   - rtype is any type convertible to Double_t                        //
00052 //                                                                      //
00053 // Find function template parameter T has to provide these functions:   //
00054 // (in addition to above requirements)                                  //
00055 //   rtype GetDist(Float_t, UInt_t) const;                              //
00056 //   - rtype is any type convertible to Float_t                         //
00057 //   rtype GetDist(const T &) const;                                    //
00058 //   - rtype is any type convertible to Float_t                         //
00059 //                                                                      //
00060 //   where T::GetDist(Float_t, UInt_t) <= T::GetDist(const T &)         //
00061 //   for any pair of events and any variable number for these events    //
00062 //                                                                      //
00063 //////////////////////////////////////////////////////////////////////////
00064 
00065 namespace TMVA
00066 {
00067    namespace kNN
00068    {
00069       template <class T>
00070       class Node
00071       {
00072 
00073       public:
00074       
00075          Node(const Node *parent, const T &event, Int_t mod);
00076          ~Node();
00077 
00078          const Node* Add(const T &event, UInt_t depth);
00079       
00080          void SetNodeL(Node *node);
00081          void SetNodeR(Node *node);
00082       
00083          const T& GetEvent() const;
00084 
00085          const Node* GetNodeL() const;
00086          const Node* GetNodeR() const;
00087          const Node* GetNodeP() const;
00088       
00089          Double_t GetWeight() const;
00090 
00091          Float_t GetVarDis() const;
00092          Float_t GetVarMin() const;
00093          Float_t GetVarMax() const;
00094 
00095          UInt_t GetMod() const;
00096 
00097          void Print() const;
00098          void Print(std::ostream& os, const std::string &offset = "") const;
00099 
00100       private: 
00101 
00102          // these methods are private and not implemented by design
00103          // use provided public constructor for all uses of this template class
00104          Node();
00105          Node(const Node &);
00106          const Node& operator=(const Node &);
00107 
00108       private:
00109 
00110          const Node* fNodeP;
00111       
00112          Node* fNodeL;
00113          Node* fNodeR;      
00114       
00115          const T fEvent;
00116       
00117          const Float_t fVarDis;
00118 
00119          Float_t fVarMin;
00120          Float_t fVarMax;
00121 
00122          const UInt_t fMod;
00123       };
00124 
00125       // recursive search for k-nearest neighbor: k = nfind 
00126       template<class T>
00127       UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
00128                         const Node<T> *node, const T &event, UInt_t nfind);
00129 
00130       // recursive search for k-nearest neighbor
00131       // find k events with sum of event weights >= nfind
00132       template<class T>
00133       UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
00134                   const Node<T> *node, const T &event, Double_t nfind, Double_t ncurr);
00135 
00136       // recursively travel upward until root node is reached
00137       template <class T>
00138       UInt_t Depth(const Node<T> *node);
00139 
00140       // print node content and content of its children
00141       //template <class T>
00142       //std::ostream& operator<<(std::ostream& os, const Node<T> &node);
00143 
00144       // 
00145       // Inlined functions for Node template
00146       //
00147       template <class T>
00148       inline void Node<T>::SetNodeL(Node<T> *node)
00149       {
00150          fNodeL = node;
00151       }
00152 
00153       template <class T>
00154       inline void Node<T>::SetNodeR(Node<T> *node)
00155       {
00156          fNodeR = node;
00157       }
00158 
00159       template <class T>
00160       inline const T& Node<T>::GetEvent() const
00161       {
00162          return fEvent;
00163       }
00164 
00165       template <class T>
00166       inline const Node<T>* Node<T>::GetNodeL() const
00167       {
00168          return fNodeL;
00169       }
00170 
00171       template <class T>
00172       inline const Node<T>* Node<T>::GetNodeR() const
00173       {
00174          return fNodeR;
00175       }
00176 
00177       template <class T>
00178       inline const Node<T>* Node<T>::GetNodeP() const
00179       {
00180          return fNodeP;
00181       }
00182 
00183       template <class T>
00184       inline Double_t Node<T>::GetWeight() const
00185       {
00186          return fEvent.GetWeight();
00187       }
00188 
00189       template <class T>
00190       inline Float_t Node<T>::GetVarDis() const
00191       {
00192          return fVarDis;
00193       }
00194 
00195       template <class T>
00196       inline Float_t Node<T>::GetVarMin() const
00197       {
00198          return fVarMin;
00199       }
00200 
00201       template <class T>
00202       inline Float_t Node<T>::GetVarMax() const
00203       {
00204          return fVarMax;
00205       }
00206 
00207       template <class T>
00208       inline UInt_t Node<T>::GetMod() const
00209       {
00210          return fMod;
00211       }
00212 
00213       // 
00214       // Inlined global function(s)
00215       //
00216       template <class T>
00217       inline UInt_t Depth(const Node<T> *node)
00218       {
00219          if (!node) return 0;
00220          else return Depth(node->GetNodeP()) + 1;
00221       }
00222 
00223    } // end of kNN namespace
00224 } // end of TMVA namespace
00225 
00226 //-------------------------------------------------------------------------------------------
00227 template<class T>
00228 TMVA::kNN::Node<T>::Node(const Node<T> *parent, const T &event, const Int_t mod) 
00229    :fNodeP(parent),
00230     fNodeL(0),
00231     fNodeR(0),
00232     fEvent(event),
00233     fVarDis(event.GetVar(mod)),
00234     fVarMin(fVarDis),
00235     fVarMax(fVarDis),
00236     fMod(mod)
00237 {}
00238 
00239 //-------------------------------------------------------------------------------------------
00240 template<class T>
00241 TMVA::kNN::Node<T>::~Node()
00242 {
00243    if (fNodeL) delete fNodeL;
00244    if (fNodeR) delete fNodeR;
00245 }
00246 
00247 //-------------------------------------------------------------------------------------------
00248 template<class T>
00249 const TMVA::kNN::Node<T>* TMVA::kNN::Node<T>::Add(const T &event, const UInt_t depth)
00250 {
00251    // This is Node member function that adds a new node to a binary tree.
00252    // each node contains maximum and minimum values of splitting variable
00253    // left or right nodes are added based on value of splitting variable
00254    
00255    assert(fMod == depth % event.GetNVar() && "Wrong recursive depth in Node<>::Add");
00256    
00257    const Float_t value = event.GetVar(fMod);
00258    
00259    fVarMin = std::min(fVarMin, value);
00260    fVarMax = std::max(fVarMax, value);
00261    
00262    Node<T> *node = 0;
00263    if (value < fVarDis) {
00264       if (fNodeL)
00265          {
00266             return fNodeL->Add(event, depth + 1);
00267          }
00268       else {
00269          fNodeL = new Node<T>(this, event, (depth + 1) % event.GetNVar());
00270          node = fNodeL;
00271       }
00272    }
00273    else {
00274       if (fNodeR) {
00275          return fNodeR->Add(event, depth + 1);
00276       }
00277       else {
00278          fNodeR = new Node<T>(this, event, (depth + 1) % event.GetNVar());
00279          node = fNodeR;
00280       }      
00281    }
00282    
00283    return node;
00284 }
00285    
00286 //-------------------------------------------------------------------------------------------
00287 template<class T>
00288 void TMVA::kNN::Node<T>::Print() const
00289 {
00290    Print(std::cout);
00291 }
00292    
00293 //-------------------------------------------------------------------------------------------
00294 template<class T>
00295 void TMVA::kNN::Node<T>::Print(std::ostream& os, const std::string &offset) const
00296 {
00297    os << offset << "-----------------------------------------------------------" << std::endl;
00298    os << offset << "Node: mod " << fMod 
00299       << " at " << fVarDis 
00300       << " with weight: " << GetWeight() << std::endl
00301       << offset << fEvent;
00302    
00303    if (fNodeL) {
00304       os << offset << "Has left node " << std::endl;
00305    }
00306    if (fNodeR) {
00307       os << offset << "Has right node" << std::endl;
00308    }
00309    
00310    if (fNodeL) {
00311       os << offset << "PrInt_t left node " << std::endl;
00312       fNodeL->Print(os, offset + " ");
00313    }
00314    if (fNodeR) {
00315       os << offset << "PrInt_t right node" << std::endl;
00316       fNodeR->Print(os, offset + " ");
00317    }
00318    
00319    if (!fNodeL && !fNodeR) {
00320       os << std::endl;
00321    }
00322 }
00323 
00324 //-------------------------------------------------------------------------------------------
00325 template<class T>
00326 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
00327                        const TMVA::kNN::Node<T> *node, const T &event, const UInt_t nfind)
00328 {
00329    // This is a global templated function that searches for k-nearest neighbors.
00330    // list contains k or less nodes that are closest to event.
00331    // only nodes with positive weights are added to list.
00332    // each node contains maximum and minimum values of splitting variable
00333    // for all its children - this range is checked to avoid descending into
00334    // nodes that are defintely outside current minimum neighbourhood.
00335    //
00336    // This function should be modified with care.
00337    //
00338 
00339    if (!node || nfind < 1) {
00340       return 0;
00341    }
00342 
00343    const Float_t value = event.GetVar(node->GetMod());     
00344 
00345    if (node->GetWeight() > 0.0) {
00346 
00347       Float_t max_dist = 0.0;
00348 
00349       if (!nlist.empty()) {
00350 
00351          max_dist = nlist.back().second;
00352          
00353          if (nlist.size() == nfind) {
00354             if (value > node->GetVarMax() && 
00355                 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
00356                return 0;
00357             }  
00358             if (value < node->GetVarMin() && 
00359                 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
00360                return 0;
00361             }
00362          }      
00363       }
00364 
00365       const Float_t distance = event.GetDist(node->GetEvent());
00366       
00367       Bool_t insert_this = kFALSE;
00368       Bool_t remove_back = kFALSE;
00369       
00370       if (nlist.size() < nfind) {
00371          insert_this = kTRUE;
00372       }
00373       else if (nlist.size() == nfind) {
00374          if (distance < max_dist) {
00375             insert_this = kTRUE;
00376             remove_back = kTRUE;
00377          }
00378       }
00379       else {
00380          std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
00381          return 1;
00382       }
00383       
00384       if (insert_this) {
00385          // need typename keyword because qualified dependent names 
00386          // are not valid types unless preceded by 'typename'.
00387          typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
00388          
00389          // find a place where current node should be inserted
00390          for (; lit != nlist.end(); ++lit) {
00391             if (distance < lit->second) {
00392                break;
00393             }
00394             else {
00395                continue;
00396             }
00397          }
00398          
00399          nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
00400          
00401          if (remove_back) {
00402             nlist.pop_back();
00403          }
00404       }
00405    }
00406    
00407    UInt_t count = 1;
00408    if (node->GetNodeL() && node->GetNodeR()) {
00409       if (value < node->GetVarDis()) {
00410          count += Find(nlist, node->GetNodeL(), event, nfind);
00411          count += Find(nlist, node->GetNodeR(), event, nfind);
00412       }
00413       else { 
00414          count += Find(nlist, node->GetNodeR(), event, nfind);
00415          count += Find(nlist, node->GetNodeL(), event, nfind);
00416       }
00417    }
00418    else {
00419       if (node->GetNodeL()) {
00420          count += Find(nlist, node->GetNodeL(), event, nfind);
00421       }
00422       if (node->GetNodeR()) {
00423          count += Find(nlist, node->GetNodeR(), event, nfind);
00424       }
00425    }
00426    
00427    return count;
00428 }
00429 
00430 
00431 //-------------------------------------------------------------------------------------------
00432 template<class T>
00433 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
00434                        const TMVA::kNN::Node<T> *node, const T &event, const Double_t nfind, Double_t ncurr)
00435 {
00436    // This is a global templated function that searches for k-nearest neighbors.
00437    // list contains all nodes that are closest to event 
00438    // and have sum of event weights >= nfind.
00439    // Only nodes with positive weights are added to list.
00440    // Requirement for used classes:
00441    //  - each node contains maximum and minimum values of splitting variable
00442    //    for all its children
00443    //  - min and max range is checked to avoid descending into
00444    //    nodes that are defintely outside current minimum neighbourhood.
00445    //
00446    // This function should be modified with care.
00447    //
00448 
00449    if (!node || !(nfind < 0.0)) {
00450       return 0;
00451    }
00452 
00453    const Float_t value = event.GetVar(node->GetMod());     
00454 
00455    if (node->GetWeight() > 0.0) {
00456 
00457       Float_t max_dist = 0.0;
00458 
00459       if (!nlist.empty()) {
00460 
00461          max_dist = nlist.back().second;
00462          
00463          if (!(ncurr < nfind)) {
00464             if (value > node->GetVarMax() && 
00465                 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
00466                return 0;
00467             }  
00468             if (value < node->GetVarMin() && 
00469                 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
00470                return 0;
00471             }
00472          }      
00473       }
00474 
00475       const Float_t distance = event.GetDist(node->GetEvent());
00476       
00477       Bool_t insert_this = kFALSE;
00478       
00479       if (ncurr < nfind) {
00480          insert_this = kTRUE;
00481       }
00482       else if (!nlist.empty()) {
00483          if (distance < max_dist) {
00484             insert_this = kTRUE;
00485          }
00486       }
00487       else {
00488          std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
00489          return 1;
00490       }
00491       
00492       if (insert_this) {
00493          // (re)compute total current weight when inserting a new node
00494          ncurr = 0;
00495 
00496          // need typename keyword because qualified dependent names 
00497          // are not valid types unless preceded by 'typename'.
00498          typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
00499 
00500          // find a place where current node should be inserted
00501          for (; lit != nlist.end(); ++lit) {
00502             if (distance < lit->second) {
00503                break;
00504             }
00505 
00506             ncurr += lit -> first -> GetWeight();
00507          }
00508          
00509          lit = nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
00510          
00511          for (; lit != nlist.end(); ++lit) {
00512             ncurr += lit -> first -> GetWeight();
00513             if (!(ncurr < nfind)) {
00514                ++lit;
00515                break;
00516             }
00517          }
00518 
00519          if(lit != nlist.end())
00520             {
00521                nlist.erase(lit, nlist.end());
00522             }
00523       }
00524    }   
00525    
00526    UInt_t count = 1;
00527    if (node->GetNodeL() && node->GetNodeR()) {
00528       if (value < node->GetVarDis()) {
00529          count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
00530          count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
00531       }
00532       else { 
00533          count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
00534          count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
00535       }
00536    }
00537    else {
00538       if (node->GetNodeL()) {
00539          count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
00540       }
00541       if (node->GetNodeR()) {
00542          count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
00543       }
00544    }
00545    
00546    return count;
00547 }
00548 
00549 #endif
00550 

Generated on Tue Jul 5 14:27:33 2011 for ROOT_528-00b_version by  doxygen 1.5.1