00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039 #include <stdexcept>
00040 #include <cstdlib>
00041 #include <queue>
00042 #include <algorithm>
00043
00044 #if ROOT_VERSION_CODE >= 364802
00045 #ifndef ROOT_TMathBase
00046 #include "TMathBase.h"
00047 #endif
00048 #else
00049 #ifndef ROOT_TMath
00050 #include "TMath.h"
00051 #endif
00052 #endif
00053 #include "TMatrixDBase.h"
00054 #include "TObjString.h"
00055 #include "TTree.h"
00056
00057 #ifndef ROOT_TMVA_MsgLogger
00058 #include "TMVA/MsgLogger.h"
00059 #endif
00060 #ifndef ROOT_TMVA_MethodBase
00061 #include "TMVA/MethodBase.h"
00062 #endif
00063 #ifndef ROOT_TMVA_Tools
00064 #include "TMVA/Tools.h"
00065 #endif
00066 #ifndef ROOT_TMVA_DataSet
00067 #include "TMVA/DataSet.h"
00068 #endif
00069 #ifndef ROOT_TMVA_Event
00070 #include "TMVA/Event.h"
00071 #endif
00072 #ifndef ROOT_TMVA_BinarySearchTree
00073 #include "TMVA/BinarySearchTree.h"
00074 #endif
00075
00076 ClassImp(TMVA::BinarySearchTree)
00077
00078
00079 TMVA::BinarySearchTree::BinarySearchTree( void ) :
00080 BinaryTree(),
00081 fPeriod ( 1 ),
00082 fCurrentDepth( 0 ),
00083 fStatisticsIsValid( kFALSE ),
00084 fSumOfWeights( 0 ),
00085 fCanNormalize( kFALSE )
00086 {
00087
00088 fNEventsW[0]=fNEventsW[1]=0.;
00089 }
00090
00091
00092 TMVA::BinarySearchTree::BinarySearchTree( const BinarySearchTree &b)
00093 : BinaryTree(),
00094 fPeriod ( b.fPeriod ),
00095 fCurrentDepth( 0 ),
00096 fStatisticsIsValid( kFALSE ),
00097 fSumOfWeights( b.fSumOfWeights ),
00098 fCanNormalize( kFALSE )
00099 {
00100
00101 fNEventsW[0]=fNEventsW[1]=0.;
00102 Log() << kFATAL << " Copy constructor not implemented yet " << Endl;
00103 }
00104
00105
00106 TMVA::BinarySearchTree::~BinarySearchTree( void )
00107 {
00108
00109
00110 for(std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator pIt = fNormalizeTreeTable.begin();
00111 pIt != fNormalizeTreeTable.end(); pIt++) {
00112 delete pIt->second;
00113 }
00114 }
00115
00116
00117 TMVA::BinarySearchTree* TMVA::BinarySearchTree::CreateFromXML(void* node, UInt_t tmva_Version_Code ) {
00118
00119 std::string type("");
00120 gTools().ReadAttr(node,"type", type);
00121 BinarySearchTree* bt = new BinarySearchTree();
00122 bt->ReadXML( node, tmva_Version_Code );
00123 return bt;
00124 }
00125
00126
00127 void TMVA::BinarySearchTree::Insert( const Event* event )
00128 {
00129
00130 fCurrentDepth=0;
00131 fStatisticsIsValid = kFALSE;
00132
00133 if (this->GetRoot() == NULL) {
00134 this->SetRoot( new BinarySearchTreeNode(event));
00135
00136 this->GetRoot()->SetPos('s');
00137 this->GetRoot()->SetDepth(0);
00138 fNNodes = 1;
00139 fSumOfWeights = event->GetWeight();
00140 ((BinarySearchTreeNode*)this->GetRoot())->SetSelector((UInt_t)0);
00141 this->SetPeriode(event->GetNVariables());
00142 }
00143 else {
00144
00145 if (event->GetNVariables() != (UInt_t)this->GetPeriode()) {
00146 Log() << kFATAL << "<Insert> event vector length != Periode specified in Binary Tree" << Endl
00147 << "--- event size: " << event->GetNVariables() << " Periode: " << this->GetPeriode() << Endl
00148 << "--- and all this when trying filling the "<<fNNodes+1<<"th Node" << Endl;
00149 }
00150
00151 this->Insert(event, this->GetRoot());
00152 }
00153
00154
00155 if (fCanNormalize) fNormalizeTreeTable.push_back( std::make_pair(0.0,new const Event(*event)) );
00156 }
00157
00158
00159 void TMVA::BinarySearchTree::Insert( const Event *event,
00160 Node *node )
00161 {
00162
00163 fCurrentDepth++;
00164 fStatisticsIsValid = kFALSE;
00165
00166 if (node->GoesLeft(*event)){
00167 if (node->GetLeft() != NULL){
00168
00169 this->Insert(event, node->GetLeft());
00170 }
00171 else {
00172
00173 BinarySearchTreeNode* current = new BinarySearchTreeNode(event);
00174 fNNodes++;
00175 fSumOfWeights += event->GetWeight();
00176 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
00177 current->SetParent(node);
00178 current->SetPos('l');
00179 current->SetDepth( node->GetDepth() + 1 );
00180 node->SetLeft(current);
00181 }
00182 }
00183 else if (node->GoesRight(*event)) {
00184 if (node->GetRight() != NULL) {
00185
00186 this->Insert(event, node->GetRight());
00187 }
00188 else {
00189
00190 BinarySearchTreeNode* current = new BinarySearchTreeNode(event);
00191 fNNodes++;
00192 fSumOfWeights += event->GetWeight();
00193 current->SetSelector(fCurrentDepth%((Int_t)event->GetNVariables()));
00194 current->SetParent(node);
00195 current->SetPos('r');
00196 current->SetDepth( node->GetDepth() + 1 );
00197 node->SetRight(current);
00198 }
00199 }
00200 else Log() << kFATAL << "<Insert> neither left nor right :)" << Endl;
00201 }
00202
00203
00204 TMVA::BinarySearchTreeNode* TMVA::BinarySearchTree::Search( Event* event ) const
00205 {
00206
00207 return this->Search( event, this->GetRoot() );
00208 }
00209
00210
00211 TMVA::BinarySearchTreeNode* TMVA::BinarySearchTree::Search(Event* event, Node* node) const
00212 {
00213
00214 if (node != NULL) {
00215
00216 if (((BinarySearchTreeNode*)(node))->EqualsMe(*event))
00217 return (BinarySearchTreeNode*)node;
00218 if (node->GoesLeft(*event))
00219 return this->Search(event, node->GetLeft());
00220 else
00221 return this->Search(event, node->GetRight());
00222 }
00223 else return NULL;
00224 }
00225
00226
00227 Double_t TMVA::BinarySearchTree::GetSumOfWeights( void ) const
00228 {
00229
00230 if (fSumOfWeights <= 0) {
00231 Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
00232 << " I call CalcStatistics which hopefully fixes things"
00233 << Endl;
00234 }
00235 if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
00236
00237 return fSumOfWeights;
00238 }
00239
00240
00241 Double_t TMVA::BinarySearchTree::GetSumOfWeights( Int_t theType ) const
00242 {
00243
00244 if (fSumOfWeights <= 0) {
00245 Log() << kWARNING << "you asked for the SumOfWeights, which is not filled yet"
00246 << " I call CalcStatistics which hopefully fixes things"
00247 << Endl;
00248 }
00249 if (fSumOfWeights <= 0) Log() << kFATAL << " Zero events in your Search Tree" <<Endl;
00250
00251 return fNEventsW[ ( theType == Types::kSignal) ? 0 : 1 ];
00252 }
00253
00254
00255 Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, const std::vector<Int_t>& theVars,
00256 Int_t theType )
00257 {
00258
00259
00260 fPeriod = theVars.size();
00261 return Fill(events, theType);
00262 }
00263
00264
00265 Double_t TMVA::BinarySearchTree::Fill( const std::vector<Event*>& events, Int_t theType )
00266 {
00267
00268
00269 UInt_t n=events.size();
00270
00271 UInt_t nevents = 0;
00272 if (fSumOfWeights != 0) {
00273 Log() << kWARNING
00274 << "You are filling a search three that is not empty.. "
00275 << " do you know what you are doing?"
00276 << Endl;
00277 }
00278 for (UInt_t ievt=0; ievt<n; ievt++) {
00279
00280 if (theType == -1 || (Int_t(events[ievt]->GetClass()) == theType) ) {
00281 this->Insert( events[ievt] );
00282 nevents++;
00283 fSumOfWeights += events[ievt]->GetWeight();
00284 }
00285 }
00286 CalcStatistics();
00287
00288 return fSumOfWeights;
00289 }
00290
00291
00292 void TMVA::BinarySearchTree::NormalizeTree ( std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftBound,
00293 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightBound,
00294 UInt_t actDim )
00295 {
00296
00297
00298
00299 if (leftBound == rightBound) return;
00300
00301 if (actDim == fPeriod) actDim = 0;
00302 for (std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator i=leftBound; i!=rightBound; i++) {
00303 i->first = i->second->GetValue( actDim );
00304 }
00305
00306 std::sort( leftBound, rightBound );
00307
00308 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator leftTemp = leftBound;
00309 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator rightTemp = rightBound;
00310
00311
00312 while (true) {
00313 rightTemp--;
00314 if (rightTemp == leftTemp ) {
00315 break;
00316 }
00317 leftTemp++;
00318 if (leftTemp == rightTemp) {
00319 break;
00320 }
00321 }
00322
00323 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator mid = leftTemp;
00324 std::vector< std::pair<Double_t, const TMVA::Event*> >::iterator midTemp = mid;
00325
00326 if (mid!=leftBound) midTemp--;
00327
00328 while (mid != leftBound && mid->second->GetValue( actDim ) == midTemp->second->GetValue( actDim )) {
00329 mid--;
00330 midTemp--;
00331 }
00332
00333 Insert( mid->second );
00334
00335
00336
00337
00338 NormalizeTree( leftBound, mid, actDim+1 );
00339 mid++;
00340
00341
00342 NormalizeTree( mid, rightBound, actDim+1 );
00343
00344
00345 return;
00346 }
00347
00348
00349 void TMVA::BinarySearchTree::NormalizeTree()
00350 {
00351
00352 SetNormalize( kFALSE );
00353 Clear( NULL );
00354 this->SetRoot(NULL);
00355 NormalizeTree( fNormalizeTreeTable.begin(), fNormalizeTreeTable.end(), 0 );
00356 }
00357
00358
00359 void TMVA::BinarySearchTree::Clear( Node* n )
00360 {
00361
00362 BinarySearchTreeNode* currentNode = (BinarySearchTreeNode*)(n == NULL ? this->GetRoot() : n);
00363
00364 if (currentNode->GetLeft() != 0) Clear( currentNode->GetLeft() );
00365 if (currentNode->GetRight() != 0) Clear( currentNode->GetRight() );
00366
00367 if (n != NULL) delete n;
00368
00369 return;
00370 }
00371
00372
00373 Double_t TMVA::BinarySearchTree::SearchVolume( Volume* volume,
00374 std::vector<const BinarySearchTreeNode*>* events )
00375 {
00376
00377
00378 return SearchVolume( this->GetRoot(), volume, 0, events );
00379 }
00380
00381
00382 Double_t TMVA::BinarySearchTree::SearchVolume( Node* t, Volume* volume, Int_t depth,
00383 std::vector<const BinarySearchTreeNode*>* events )
00384 {
00385
00386
00387
00388 if (t==NULL) return 0;
00389
00390 BinarySearchTreeNode* st = (BinarySearchTreeNode*)t;
00391
00392 Double_t count = 0.0;
00393 if (InVolume( st->GetEventV(), volume )) {
00394 count += st->GetWeight();
00395 if (NULL != events) events->push_back( st );
00396 }
00397 if (st->GetLeft()==NULL && st->GetRight()==NULL) {
00398
00399 return count;
00400 }
00401
00402 Bool_t tl, tr;
00403 Int_t d = depth%this->GetPeriode();
00404 if (d != st->GetSelector()) {
00405 Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
00406 << d << " != " << "node "<< st->GetSelector() << Endl;
00407 }
00408 tl = (*(volume->fLower))[d] < st->GetEventV()[d];
00409 tr = (*(volume->fUpper))[d] >= st->GetEventV()[d];
00410
00411 if (tl) count += SearchVolume( st->GetLeft(), volume, (depth+1), events );
00412 if (tr) count += SearchVolume( st->GetRight(), volume, (depth+1), events );
00413
00414 return count;
00415 }
00416
00417 Bool_t TMVA::BinarySearchTree::InVolume(const std::vector<Float_t>& event, Volume* volume ) const
00418 {
00419
00420
00421 Bool_t result = false;
00422 for (UInt_t ivar=0; ivar< fPeriod; ivar++) {
00423 result = ( (*(volume->fLower))[ivar] < event[ivar] &&
00424 (*(volume->fUpper))[ivar] >= event[ivar] );
00425 if (!result) break;
00426 }
00427 return result;
00428 }
00429
00430
00431 void TMVA::BinarySearchTree::CalcStatistics( Node* n )
00432 {
00433
00434 if (fStatisticsIsValid) return;
00435
00436 BinarySearchTreeNode * currentNode = (BinarySearchTreeNode*)n;
00437
00438
00439 if (n == NULL) {
00440 fSumOfWeights = 0;
00441 for (Int_t sb=0; sb<2; sb++) {
00442 fNEventsW[sb] = 0;
00443 fMeans[sb] = std::vector<Float_t>(fPeriod);
00444 fRMS[sb] = std::vector<Float_t>(fPeriod);
00445 fMin[sb] = std::vector<Float_t>(fPeriod);
00446 fMax[sb] = std::vector<Float_t>(fPeriod);
00447 fSum[sb] = std::vector<Double_t>(fPeriod);
00448 fSumSq[sb] = std::vector<Double_t>(fPeriod);
00449 for (UInt_t j=0; j<fPeriod; j++) {
00450 fMeans[sb][j] = fRMS[sb][j] = fSum[sb][j] = fSumSq[sb][j] = 0;
00451 fMin[sb][j] = FLT_MAX;
00452 fMax[sb][j] = -FLT_MAX;
00453 }
00454 }
00455 currentNode = (BinarySearchTreeNode*) this->GetRoot();
00456 if (currentNode == NULL) return;
00457 }
00458
00459 const std::vector<Float_t> & evtVec = currentNode->GetEventV();
00460 Double_t weight = currentNode->GetWeight();
00461
00462 Int_t type = currentNode->IsSignal() ? 0 : 1;
00463 fNEventsW[type] += weight;
00464 fSumOfWeights += weight;
00465
00466 for (UInt_t j=0; j<fPeriod; j++) {
00467 Float_t val = evtVec[j];
00468 fSum[type][j] += val*weight;
00469 fSumSq[type][j] += val*val*weight;
00470 if (val < fMin[type][j]) fMin[type][j] = val;
00471 if (val > fMax[type][j]) fMax[type][j] = val;
00472 }
00473
00474 if ( (currentNode->GetLeft() != NULL) ) CalcStatistics( currentNode->GetLeft() );
00475 if ( (currentNode->GetRight() != NULL) ) CalcStatistics( currentNode->GetRight() );
00476
00477 if (n == NULL) {
00478 for (Int_t sb=0; sb<2; sb++) {
00479 for (UInt_t j=0; j<fPeriod; j++) {
00480 if (fNEventsW[sb] == 0) { fMeans[sb][j] = fRMS[sb][j] = 0; continue; }
00481 fMeans[sb][j] = fSum[sb][j]/fNEventsW[sb];
00482 fRMS[sb][j] = TMath::Sqrt(fSumSq[sb][j]/fNEventsW[sb] - fMeans[sb][j]*fMeans[sb][j]);
00483 }
00484 }
00485 fStatisticsIsValid = kTRUE;
00486 }
00487
00488 return;
00489 }
00490
00491 Int_t TMVA::BinarySearchTree::SearchVolumeWithMaxLimit( Volume *volume, std::vector<const BinarySearchTreeNode*>* events,
00492 Int_t max_points )
00493 {
00494
00495
00496 if (this->GetRoot() == NULL) return 0;
00497
00498 std::queue< std::pair< const BinarySearchTreeNode*, Int_t > > queue;
00499 std::pair< const BinarySearchTreeNode*, Int_t > st = std::make_pair( (const BinarySearchTreeNode*)this->GetRoot(), 0 );
00500 queue.push( st );
00501
00502 Int_t count = 0;
00503
00504 while ( !queue.empty() ) {
00505 st = queue.front(); queue.pop();
00506
00507 if (count == max_points)
00508 return count;
00509
00510 if (InVolume( st.first->GetEventV(), volume )) {
00511 count++;
00512 if (NULL != events) events->push_back( st.first );
00513 }
00514
00515 Bool_t tl, tr;
00516 Int_t d = st.second;
00517 if ( d == Int_t(this->GetPeriode()) ) d = 0;
00518
00519 if (d != st.first->GetSelector()) {
00520 Log() << kFATAL << "<SearchVolume> selector in Searchvolume "
00521 << d << " != " << "node "<< st.first->GetSelector() << Endl;
00522 }
00523
00524 tl = (*(volume->fLower))[d] < st.first->GetEventV()[d] && st.first->GetLeft() != NULL;
00525 tr = (*(volume->fUpper))[d] >= st.first->GetEventV()[d] && st.first->GetRight() != NULL;
00526
00527 if (tl) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetLeft(), d+1 ) );
00528 if (tr) queue.push( std::make_pair( (const BinarySearchTreeNode*)st.first->GetRight(), d+1 ) );
00529 }
00530
00531 return count;
00532 }