00001 
00002 
00003 
00004 
00005 
00006 
00007 
00008 
00009 
00010 
00011 
00012 
00013 
00014 
00015 
00016 
00017 
00018 
00019 
00020 
00021 
00022 
00023 
00024 
00025 
00026 #ifndef ROOT_TMVA_NodekNN
00027 #define ROOT_TMVA_NodekNN
00028 
00029 
#include <algorithm>
#include <cassert>
#include <iostream>
#include <list>
#include <string>


#ifndef ROOT_Rtypes
#include "Rtypes.h"
#endif
00038 
00039 
00040 
00041 
00042 
00043 
00044 
00045 
00046 
00047 
00048 
00049 
00050 
00051 
00052 
00053 
00054 
00055 
00056 
00057 
00058 
00059 
00060 
00061 
00062 
00063 
00064 
00065 namespace TMVA
00066 {
00067    namespace kNN
00068    {
00069       template <class T>
00070       class Node
00071       {
00072 
00073       public:
00074       
00075          Node(const Node *parent, const T &event, Int_t mod);
00076          ~Node();
00077 
00078          const Node* Add(const T &event, UInt_t depth);
00079       
00080          void SetNodeL(Node *node);
00081          void SetNodeR(Node *node);
00082       
00083          const T& GetEvent() const;
00084 
00085          const Node* GetNodeL() const;
00086          const Node* GetNodeR() const;
00087          const Node* GetNodeP() const;
00088       
00089          Double_t GetWeight() const;
00090 
00091          Float_t GetVarDis() const;
00092          Float_t GetVarMin() const;
00093          Float_t GetVarMax() const;
00094 
00095          UInt_t GetMod() const;
00096 
00097          void Print() const;
00098          void Print(std::ostream& os, const std::string &offset = "") const;
00099 
00100       private: 
00101 
00102          
00103          
00104          Node();
00105          Node(const Node &);
00106          const Node& operator=(const Node &);
00107 
00108       private:
00109 
00110          const Node* fNodeP;
00111       
00112          Node* fNodeL;
00113          Node* fNodeR;      
00114       
00115          const T fEvent;
00116       
00117          const Float_t fVarDis;
00118 
00119          Float_t fVarMin;
00120          Float_t fVarMax;
00121 
00122          const UInt_t fMod;
00123       };
00124 
00125       
00126       template<class T>
00127       UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
00128                         const Node<T> *node, const T &event, UInt_t nfind);
00129 
00130       
00131       
00132       template<class T>
00133       UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
00134                   const Node<T> *node, const T &event, Double_t nfind, Double_t ncurr);
00135 
00136       
00137       template <class T>
00138       UInt_t Depth(const Node<T> *node);
00139 
00140       
00141       
00142       
00143 
00144       
00145       
00146       
00147       template <class T>
00148       inline void Node<T>::SetNodeL(Node<T> *node)
00149       {
00150          fNodeL = node;
00151       }
00152 
00153       template <class T>
00154       inline void Node<T>::SetNodeR(Node<T> *node)
00155       {
00156          fNodeR = node;
00157       }
00158 
00159       template <class T>
00160       inline const T& Node<T>::GetEvent() const
00161       {
00162          return fEvent;
00163       }
00164 
00165       template <class T>
00166       inline const Node<T>* Node<T>::GetNodeL() const
00167       {
00168          return fNodeL;
00169       }
00170 
00171       template <class T>
00172       inline const Node<T>* Node<T>::GetNodeR() const
00173       {
00174          return fNodeR;
00175       }
00176 
00177       template <class T>
00178       inline const Node<T>* Node<T>::GetNodeP() const
00179       {
00180          return fNodeP;
00181       }
00182 
00183       template <class T>
00184       inline Double_t Node<T>::GetWeight() const
00185       {
00186          return fEvent.GetWeight();
00187       }
00188 
00189       template <class T>
00190       inline Float_t Node<T>::GetVarDis() const
00191       {
00192          return fVarDis;
00193       }
00194 
00195       template <class T>
00196       inline Float_t Node<T>::GetVarMin() const
00197       {
00198          return fVarMin;
00199       }
00200 
00201       template <class T>
00202       inline Float_t Node<T>::GetVarMax() const
00203       {
00204          return fVarMax;
00205       }
00206 
00207       template <class T>
00208       inline UInt_t Node<T>::GetMod() const
00209       {
00210          return fMod;
00211       }
00212 
00213       
00214       
00215       
00216       template <class T>
00217       inline UInt_t Depth(const Node<T> *node)
00218       {
00219          if (!node) return 0;
00220          else return Depth(node->GetNodeP()) + 1;
00221       }
00222 
00223    } 
00224 } 
00225 
00226 
00227 template<class T>
00228 TMVA::kNN::Node<T>::Node(const Node<T> *parent, const T &event, const Int_t mod) 
00229    :fNodeP(parent),
00230     fNodeL(0),
00231     fNodeR(0),
00232     fEvent(event),
00233     fVarDis(event.GetVar(mod)),
00234     fVarMin(fVarDis),
00235     fVarMax(fVarDis),
00236     fMod(mod)
00237 {}
00238 
00239 
00240 template<class T>
00241 TMVA::kNN::Node<T>::~Node()
00242 {
00243    if (fNodeL) delete fNodeL;
00244    if (fNodeR) delete fNodeR;
00245 }
00246 
00247 
00248 template<class T>
00249 const TMVA::kNN::Node<T>* TMVA::kNN::Node<T>::Add(const T &event, const UInt_t depth)
00250 {
00251    
00252    
00253    
00254    
00255    assert(fMod == depth % event.GetNVar() && "Wrong recursive depth in Node<>::Add");
00256    
00257    const Float_t value = event.GetVar(fMod);
00258    
00259    fVarMin = std::min(fVarMin, value);
00260    fVarMax = std::max(fVarMax, value);
00261    
00262    Node<T> *node = 0;
00263    if (value < fVarDis) {
00264       if (fNodeL)
00265          {
00266             return fNodeL->Add(event, depth + 1);
00267          }
00268       else {
00269          fNodeL = new Node<T>(this, event, (depth + 1) % event.GetNVar());
00270          node = fNodeL;
00271       }
00272    }
00273    else {
00274       if (fNodeR) {
00275          return fNodeR->Add(event, depth + 1);
00276       }
00277       else {
00278          fNodeR = new Node<T>(this, event, (depth + 1) % event.GetNVar());
00279          node = fNodeR;
00280       }      
00281    }
00282    
00283    return node;
00284 }
00285    
00286 
00287 template<class T>
00288 void TMVA::kNN::Node<T>::Print() const
00289 {
00290    Print(std::cout);
00291 }
00292    
00293 
00294 template<class T>
00295 void TMVA::kNN::Node<T>::Print(std::ostream& os, const std::string &offset) const
00296 {
00297    os << offset << "-----------------------------------------------------------" << std::endl;
00298    os << offset << "Node: mod " << fMod 
00299       << " at " << fVarDis 
00300       << " with weight: " << GetWeight() << std::endl
00301       << offset << fEvent;
00302    
00303    if (fNodeL) {
00304       os << offset << "Has left node " << std::endl;
00305    }
00306    if (fNodeR) {
00307       os << offset << "Has right node" << std::endl;
00308    }
00309    
00310    if (fNodeL) {
00311       os << offset << "PrInt_t left node " << std::endl;
00312       fNodeL->Print(os, offset + " ");
00313    }
00314    if (fNodeR) {
00315       os << offset << "PrInt_t right node" << std::endl;
00316       fNodeR->Print(os, offset + " ");
00317    }
00318    
00319    if (!fNodeL && !fNodeR) {
00320       os << std::endl;
00321    }
00322 }
00323 
00324 
00325 template<class T>
00326 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
00327                        const TMVA::kNN::Node<T> *node, const T &event, const UInt_t nfind)
00328 {
00329    
00330    
00331    
00332    
00333    
00334    
00335    
00336    
00337    
00338 
00339    if (!node || nfind < 1) {
00340       return 0;
00341    }
00342 
00343    const Float_t value = event.GetVar(node->GetMod());     
00344 
00345    if (node->GetWeight() > 0.0) {
00346 
00347       Float_t max_dist = 0.0;
00348 
00349       if (!nlist.empty()) {
00350 
00351          max_dist = nlist.back().second;
00352          
00353          if (nlist.size() == nfind) {
00354             if (value > node->GetVarMax() && 
00355                 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
00356                return 0;
00357             }  
00358             if (value < node->GetVarMin() && 
00359                 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
00360                return 0;
00361             }
00362          }      
00363       }
00364 
00365       const Float_t distance = event.GetDist(node->GetEvent());
00366       
00367       Bool_t insert_this = kFALSE;
00368       Bool_t remove_back = kFALSE;
00369       
00370       if (nlist.size() < nfind) {
00371          insert_this = kTRUE;
00372       }
00373       else if (nlist.size() == nfind) {
00374          if (distance < max_dist) {
00375             insert_this = kTRUE;
00376             remove_back = kTRUE;
00377          }
00378       }
00379       else {
00380          std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
00381          return 1;
00382       }
00383       
00384       if (insert_this) {
00385          
00386          
00387          typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
00388          
00389          
00390          for (; lit != nlist.end(); ++lit) {
00391             if (distance < lit->second) {
00392                break;
00393             }
00394             else {
00395                continue;
00396             }
00397          }
00398          
00399          nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
00400          
00401          if (remove_back) {
00402             nlist.pop_back();
00403          }
00404       }
00405    }
00406    
00407    UInt_t count = 1;
00408    if (node->GetNodeL() && node->GetNodeR()) {
00409       if (value < node->GetVarDis()) {
00410          count += Find(nlist, node->GetNodeL(), event, nfind);
00411          count += Find(nlist, node->GetNodeR(), event, nfind);
00412       }
00413       else { 
00414          count += Find(nlist, node->GetNodeR(), event, nfind);
00415          count += Find(nlist, node->GetNodeL(), event, nfind);
00416       }
00417    }
00418    else {
00419       if (node->GetNodeL()) {
00420          count += Find(nlist, node->GetNodeL(), event, nfind);
00421       }
00422       if (node->GetNodeR()) {
00423          count += Find(nlist, node->GetNodeR(), event, nfind);
00424       }
00425    }
00426    
00427    return count;
00428 }
00429 
00430 
00431 
00432 template<class T>
00433 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
00434                        const TMVA::kNN::Node<T> *node, const T &event, const Double_t nfind, Double_t ncurr)
00435 {
00436    
00437    
00438    
00439    
00440    
00441    
00442    
00443    
00444    
00445    
00446    
00447    
00448 
00449    if (!node || !(nfind < 0.0)) {
00450       return 0;
00451    }
00452 
00453    const Float_t value = event.GetVar(node->GetMod());     
00454 
00455    if (node->GetWeight() > 0.0) {
00456 
00457       Float_t max_dist = 0.0;
00458 
00459       if (!nlist.empty()) {
00460 
00461          max_dist = nlist.back().second;
00462          
00463          if (!(ncurr < nfind)) {
00464             if (value > node->GetVarMax() && 
00465                 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
00466                return 0;
00467             }  
00468             if (value < node->GetVarMin() && 
00469                 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
00470                return 0;
00471             }
00472          }      
00473       }
00474 
00475       const Float_t distance = event.GetDist(node->GetEvent());
00476       
00477       Bool_t insert_this = kFALSE;
00478       
00479       if (ncurr < nfind) {
00480          insert_this = kTRUE;
00481       }
00482       else if (!nlist.empty()) {
00483          if (distance < max_dist) {
00484             insert_this = kTRUE;
00485          }
00486       }
00487       else {
00488          std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
00489          return 1;
00490       }
00491       
00492       if (insert_this) {
00493          
00494          ncurr = 0;
00495 
00496          
00497          
00498          typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
00499 
00500          
00501          for (; lit != nlist.end(); ++lit) {
00502             if (distance < lit->second) {
00503                break;
00504             }
00505 
00506             ncurr += lit -> first -> GetWeight();
00507          }
00508          
00509          lit = nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
00510          
00511          for (; lit != nlist.end(); ++lit) {
00512             ncurr += lit -> first -> GetWeight();
00513             if (!(ncurr < nfind)) {
00514                ++lit;
00515                break;
00516             }
00517          }
00518 
00519          if(lit != nlist.end())
00520             {
00521                nlist.erase(lit, nlist.end());
00522             }
00523       }
00524    }   
00525    
00526    UInt_t count = 1;
00527    if (node->GetNodeL() && node->GetNodeR()) {
00528       if (value < node->GetVarDis()) {
00529          count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
00530          count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
00531       }
00532       else { 
00533          count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
00534          count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
00535       }
00536    }
00537    else {
00538       if (node->GetNodeL()) {
00539          count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
00540       }
00541       if (node->GetNodeR()) {
00542          count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
00543       }
00544    }
00545    
00546    return count;
00547 }
00548 
00549 #endif
00550