00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026 #ifndef ROOT_TMVA_NodekNN
00027 #define ROOT_TMVA_NodekNN
00028
00029
#include <algorithm>
#include <cassert>
#include <iostream>
#include <list>
#include <string>


#ifndef ROOT_Rtypes
#include "Rtypes.h"
#endif
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055
00056
00057
00058
00059
00060
00061
00062
00063
00064
00065 namespace TMVA
00066 {
00067 namespace kNN
00068 {
00069 template <class T>
00070 class Node
00071 {
00072
00073 public:
00074
00075 Node(const Node *parent, const T &event, Int_t mod);
00076 ~Node();
00077
00078 const Node* Add(const T &event, UInt_t depth);
00079
00080 void SetNodeL(Node *node);
00081 void SetNodeR(Node *node);
00082
00083 const T& GetEvent() const;
00084
00085 const Node* GetNodeL() const;
00086 const Node* GetNodeR() const;
00087 const Node* GetNodeP() const;
00088
00089 Double_t GetWeight() const;
00090
00091 Float_t GetVarDis() const;
00092 Float_t GetVarMin() const;
00093 Float_t GetVarMax() const;
00094
00095 UInt_t GetMod() const;
00096
00097 void Print() const;
00098 void Print(std::ostream& os, const std::string &offset = "") const;
00099
00100 private:
00101
00102
00103
00104 Node();
00105 Node(const Node &);
00106 const Node& operator=(const Node &);
00107
00108 private:
00109
00110 const Node* fNodeP;
00111
00112 Node* fNodeL;
00113 Node* fNodeR;
00114
00115 const T fEvent;
00116
00117 const Float_t fVarDis;
00118
00119 Float_t fVarMin;
00120 Float_t fVarMax;
00121
00122 const UInt_t fMod;
00123 };
00124
00125
00126 template<class T>
00127 UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
00128 const Node<T> *node, const T &event, UInt_t nfind);
00129
00130
00131
00132 template<class T>
00133 UInt_t Find(std::list<std::pair<const Node<T> *, Float_t> > &nlist,
00134 const Node<T> *node, const T &event, Double_t nfind, Double_t ncurr);
00135
00136
00137 template <class T>
00138 UInt_t Depth(const Node<T> *node);
00139
00140
00141
00142
00143
00144
00145
00146
00147 template <class T>
00148 inline void Node<T>::SetNodeL(Node<T> *node)
00149 {
00150 fNodeL = node;
00151 }
00152
00153 template <class T>
00154 inline void Node<T>::SetNodeR(Node<T> *node)
00155 {
00156 fNodeR = node;
00157 }
00158
00159 template <class T>
00160 inline const T& Node<T>::GetEvent() const
00161 {
00162 return fEvent;
00163 }
00164
00165 template <class T>
00166 inline const Node<T>* Node<T>::GetNodeL() const
00167 {
00168 return fNodeL;
00169 }
00170
00171 template <class T>
00172 inline const Node<T>* Node<T>::GetNodeR() const
00173 {
00174 return fNodeR;
00175 }
00176
00177 template <class T>
00178 inline const Node<T>* Node<T>::GetNodeP() const
00179 {
00180 return fNodeP;
00181 }
00182
00183 template <class T>
00184 inline Double_t Node<T>::GetWeight() const
00185 {
00186 return fEvent.GetWeight();
00187 }
00188
00189 template <class T>
00190 inline Float_t Node<T>::GetVarDis() const
00191 {
00192 return fVarDis;
00193 }
00194
00195 template <class T>
00196 inline Float_t Node<T>::GetVarMin() const
00197 {
00198 return fVarMin;
00199 }
00200
00201 template <class T>
00202 inline Float_t Node<T>::GetVarMax() const
00203 {
00204 return fVarMax;
00205 }
00206
00207 template <class T>
00208 inline UInt_t Node<T>::GetMod() const
00209 {
00210 return fMod;
00211 }
00212
00213
00214
00215
00216 template <class T>
00217 inline UInt_t Depth(const Node<T> *node)
00218 {
00219 if (!node) return 0;
00220 else return Depth(node->GetNodeP()) + 1;
00221 }
00222
00223 }
00224 }
00225
00226
00227 template<class T>
00228 TMVA::kNN::Node<T>::Node(const Node<T> *parent, const T &event, const Int_t mod)
00229 :fNodeP(parent),
00230 fNodeL(0),
00231 fNodeR(0),
00232 fEvent(event),
00233 fVarDis(event.GetVar(mod)),
00234 fVarMin(fVarDis),
00235 fVarMax(fVarDis),
00236 fMod(mod)
00237 {}
00238
00239
00240 template<class T>
00241 TMVA::kNN::Node<T>::~Node()
00242 {
00243 if (fNodeL) delete fNodeL;
00244 if (fNodeR) delete fNodeR;
00245 }
00246
00247
00248 template<class T>
00249 const TMVA::kNN::Node<T>* TMVA::kNN::Node<T>::Add(const T &event, const UInt_t depth)
00250 {
00251
00252
00253
00254
00255 assert(fMod == depth % event.GetNVar() && "Wrong recursive depth in Node<>::Add");
00256
00257 const Float_t value = event.GetVar(fMod);
00258
00259 fVarMin = std::min(fVarMin, value);
00260 fVarMax = std::max(fVarMax, value);
00261
00262 Node<T> *node = 0;
00263 if (value < fVarDis) {
00264 if (fNodeL)
00265 {
00266 return fNodeL->Add(event, depth + 1);
00267 }
00268 else {
00269 fNodeL = new Node<T>(this, event, (depth + 1) % event.GetNVar());
00270 node = fNodeL;
00271 }
00272 }
00273 else {
00274 if (fNodeR) {
00275 return fNodeR->Add(event, depth + 1);
00276 }
00277 else {
00278 fNodeR = new Node<T>(this, event, (depth + 1) % event.GetNVar());
00279 node = fNodeR;
00280 }
00281 }
00282
00283 return node;
00284 }
00285
00286
00287 template<class T>
00288 void TMVA::kNN::Node<T>::Print() const
00289 {
00290 Print(std::cout);
00291 }
00292
00293
00294 template<class T>
00295 void TMVA::kNN::Node<T>::Print(std::ostream& os, const std::string &offset) const
00296 {
00297 os << offset << "-----------------------------------------------------------" << std::endl;
00298 os << offset << "Node: mod " << fMod
00299 << " at " << fVarDis
00300 << " with weight: " << GetWeight() << std::endl
00301 << offset << fEvent;
00302
00303 if (fNodeL) {
00304 os << offset << "Has left node " << std::endl;
00305 }
00306 if (fNodeR) {
00307 os << offset << "Has right node" << std::endl;
00308 }
00309
00310 if (fNodeL) {
00311 os << offset << "PrInt_t left node " << std::endl;
00312 fNodeL->Print(os, offset + " ");
00313 }
00314 if (fNodeR) {
00315 os << offset << "PrInt_t right node" << std::endl;
00316 fNodeR->Print(os, offset + " ");
00317 }
00318
00319 if (!fNodeL && !fNodeR) {
00320 os << std::endl;
00321 }
00322 }
00323
00324
00325 template<class T>
00326 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
00327 const TMVA::kNN::Node<T> *node, const T &event, const UInt_t nfind)
00328 {
00329
00330
00331
00332
00333
00334
00335
00336
00337
00338
00339 if (!node || nfind < 1) {
00340 return 0;
00341 }
00342
00343 const Float_t value = event.GetVar(node->GetMod());
00344
00345 if (node->GetWeight() > 0.0) {
00346
00347 Float_t max_dist = 0.0;
00348
00349 if (!nlist.empty()) {
00350
00351 max_dist = nlist.back().second;
00352
00353 if (nlist.size() == nfind) {
00354 if (value > node->GetVarMax() &&
00355 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
00356 return 0;
00357 }
00358 if (value < node->GetVarMin() &&
00359 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
00360 return 0;
00361 }
00362 }
00363 }
00364
00365 const Float_t distance = event.GetDist(node->GetEvent());
00366
00367 Bool_t insert_this = kFALSE;
00368 Bool_t remove_back = kFALSE;
00369
00370 if (nlist.size() < nfind) {
00371 insert_this = kTRUE;
00372 }
00373 else if (nlist.size() == nfind) {
00374 if (distance < max_dist) {
00375 insert_this = kTRUE;
00376 remove_back = kTRUE;
00377 }
00378 }
00379 else {
00380 std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
00381 return 1;
00382 }
00383
00384 if (insert_this) {
00385
00386
00387 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
00388
00389
00390 for (; lit != nlist.end(); ++lit) {
00391 if (distance < lit->second) {
00392 break;
00393 }
00394 else {
00395 continue;
00396 }
00397 }
00398
00399 nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
00400
00401 if (remove_back) {
00402 nlist.pop_back();
00403 }
00404 }
00405 }
00406
00407 UInt_t count = 1;
00408 if (node->GetNodeL() && node->GetNodeR()) {
00409 if (value < node->GetVarDis()) {
00410 count += Find(nlist, node->GetNodeL(), event, nfind);
00411 count += Find(nlist, node->GetNodeR(), event, nfind);
00412 }
00413 else {
00414 count += Find(nlist, node->GetNodeR(), event, nfind);
00415 count += Find(nlist, node->GetNodeL(), event, nfind);
00416 }
00417 }
00418 else {
00419 if (node->GetNodeL()) {
00420 count += Find(nlist, node->GetNodeL(), event, nfind);
00421 }
00422 if (node->GetNodeR()) {
00423 count += Find(nlist, node->GetNodeR(), event, nfind);
00424 }
00425 }
00426
00427 return count;
00428 }
00429
00430
00431
00432 template<class T>
00433 UInt_t TMVA::kNN::Find(std::list<std::pair<const TMVA::kNN::Node<T> *, Float_t> > &nlist,
00434 const TMVA::kNN::Node<T> *node, const T &event, const Double_t nfind, Double_t ncurr)
00435 {
00436
00437
00438
00439
00440
00441
00442
00443
00444
00445
00446
00447
00448
00449 if (!node || !(nfind < 0.0)) {
00450 return 0;
00451 }
00452
00453 const Float_t value = event.GetVar(node->GetMod());
00454
00455 if (node->GetWeight() > 0.0) {
00456
00457 Float_t max_dist = 0.0;
00458
00459 if (!nlist.empty()) {
00460
00461 max_dist = nlist.back().second;
00462
00463 if (!(ncurr < nfind)) {
00464 if (value > node->GetVarMax() &&
00465 event.GetDist(node->GetVarMax(), node->GetMod()) > max_dist) {
00466 return 0;
00467 }
00468 if (value < node->GetVarMin() &&
00469 event.GetDist(node->GetVarMin(), node->GetMod()) > max_dist) {
00470 return 0;
00471 }
00472 }
00473 }
00474
00475 const Float_t distance = event.GetDist(node->GetEvent());
00476
00477 Bool_t insert_this = kFALSE;
00478
00479 if (ncurr < nfind) {
00480 insert_this = kTRUE;
00481 }
00482 else if (!nlist.empty()) {
00483 if (distance < max_dist) {
00484 insert_this = kTRUE;
00485 }
00486 }
00487 else {
00488 std::cerr << "TMVA::kNN::Find() - logic error in recursive procedure" << std::endl;
00489 return 1;
00490 }
00491
00492 if (insert_this) {
00493
00494 ncurr = 0;
00495
00496
00497
00498 typename std::list<std::pair<const Node<T> *, Float_t> >::iterator lit = nlist.begin();
00499
00500
00501 for (; lit != nlist.end(); ++lit) {
00502 if (distance < lit->second) {
00503 break;
00504 }
00505
00506 ncurr += lit -> first -> GetWeight();
00507 }
00508
00509 lit = nlist.insert(lit, std::pair<const Node<T> *, Float_t>(node, distance));
00510
00511 for (; lit != nlist.end(); ++lit) {
00512 ncurr += lit -> first -> GetWeight();
00513 if (!(ncurr < nfind)) {
00514 ++lit;
00515 break;
00516 }
00517 }
00518
00519 if(lit != nlist.end())
00520 {
00521 nlist.erase(lit, nlist.end());
00522 }
00523 }
00524 }
00525
00526 UInt_t count = 1;
00527 if (node->GetNodeL() && node->GetNodeR()) {
00528 if (value < node->GetVarDis()) {
00529 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
00530 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
00531 }
00532 else {
00533 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
00534 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
00535 }
00536 }
00537 else {
00538 if (node->GetNodeL()) {
00539 count += Find(nlist, node->GetNodeL(), event, nfind, ncurr);
00540 }
00541 if (node->GetNodeR()) {
00542 count += Find(nlist, node->GetNodeR(), event, nfind, ncurr);
00543 }
00544 }
00545
00546 return count;
00547 }
00548
00549 #endif
00550