DataSet.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: DataSet.cxx 36999 2010-11-26 23:58:45Z stelzer $
00002 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : DataSet                                                               *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Implementation (see header for description)                               *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00015  *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
00016  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00017  *                                                                                *
00018  * Copyright (c) 2006:                                                            *
00019  *      CERN, Switzerland                                                         *
00020  *      MPI-K Heidelberg, Germany                                                 *
00021  *                                                                                *
00022  * Redistribution and use in source and binary forms, with or without             *
00023  * modification, are permitted according to the terms listed in LICENSE           *
00024  * (http://tmva.sourceforge.net/LICENSE)                                          *
00025  **********************************************************************************/
00026 
00027 #include <vector>
00028 #include <algorithm>
00029 #include <cstdlib>
00030 #include <stdexcept>
00031 #include <algorithm>
00032 
00033 #ifndef ROOT_TMVA_DataSetInfo
00034 #include "TMVA/DataSetInfo.h"
00035 #endif
00036 #ifndef ROOT_TMVA_DataSet
00037 #include "TMVA/DataSet.h"
00038 #endif
00039 #ifndef ROOT_TMVA_Event
00040 #include "TMVA/Event.h"
00041 #endif
00042 #ifndef ROOT_TMVA_MsgLogger
00043 #include "TMVA/MsgLogger.h"
00044 #endif
00045 #ifndef ROOT_TMVA_ResultsRegression
00046 #include "TMVA/ResultsRegression.h"
00047 #endif
00048 #ifndef ROOT_TMVA_ResultsClassification
00049 #include "TMVA/ResultsClassification.h"
00050 #endif
00051 #ifndef ROOT_TMVA_ResultsMulticlass
00052 #include "TMVA/ResultsMulticlass.h"
00053 #endif
00054 #ifndef ROOT_TMVA_Configurable
00055 #include "TMVA/Configurable.h"
00056 #endif
00057 
00058 //_______________________________________________________________________
00059 TMVA::DataSet::DataSet(const DataSetInfo& dsi) 
00060    : fdsi(dsi),
00061      fEventCollection(4,(std::vector<Event*>*)0),
00062      fCurrentTreeIdx(0),
00063      fCurrentEventIdx(0),
00064      fHasNegativeEventWeights(kFALSE),
00065      fLogger( new MsgLogger(TString(TString("Dataset:")+dsi.GetName()).Data()) ),
00066      fTrainingBlockSize(0)
00067 {
00068    // constructor
00069    for (UInt_t i=0; i<4; i++) fEventCollection[i] = new std::vector<Event*>();
00070    
00071    fClassEvents.resize(4);
00072    fBlockBelongToTraining.reserve(10);
00073    fBlockBelongToTraining.push_back(kTRUE);
00074 
00075    // sampling
00076    fSamplingRandom = 0;
00077 
00078    Int_t treeNum = 2;
00079    fSampling.resize( treeNum );  
00080    fSamplingNEvents.resize( treeNum ); 
00081    fSamplingWeight.resize(treeNum);
00082   
00083    for (Int_t treeIdx = 0; treeIdx < treeNum; treeIdx++) {
00084       fSampling.at(treeIdx) = kFALSE;
00085       fSamplingNEvents.at(treeIdx) = 0;
00086       fSamplingWeight.at(treeIdx) = 1.0;
00087    }
00088 }
00089 
//_______________________________________________________________________
TMVA::DataSet::~DataSet() 
{
   // destructor: the dataset owns its events, its Results objects, its
   // sampling pair lists, the sampling random generator and the logger

   // delete event collection
   Bool_t deleteEvents=true; // dataset owns the events /JS
   DestroyCollection( Types::kTraining, deleteEvents );
   DestroyCollection( Types::kTesting, deleteEvents );
   
   fBlockBelongToTraining.clear();
   // delete results (one map of named Results objects per tree type)
   for (std::vector< std::map< TString, Results* > >::iterator it = fResults.begin(); it != fResults.end(); it++) {
      for (std::map< TString, Results* >::iterator itMap = (*it).begin(); itMap != (*it).end(); itMap++) {
         delete itMap->second;
      }
   }

   // delete sampling
   if (fSamplingRandom != 0 ) delete fSamplingRandom;

   // the sampling event lists own their (weight, event-index) pairs
   std::vector< std::pair< Float_t, Long64_t >* >::iterator itEv;
   std::vector< std::vector<std::pair< Float_t, Long64_t >* > >::iterator treeIt;
   for (treeIt = fSamplingEventList.begin(); treeIt != fSamplingEventList.end(); treeIt++ ) {
      for (itEv = (*treeIt).begin(); itEv != (*treeIt).end(); itEv++) {
         delete (*itEv);
      }
   }

   // need also to delete fEventCollections[2] and [3], not sure if they are used
   // NOTE(review): after DivideTrainingSet()/ApplyTrainingSetDivision() the
   // training and validation collections share Event pointers with the
   // kTrainingOriginal collection, so deleting events in all four
   // collections looks like a potential double delete -- confirm against
   // how TreeIndex() maps these collections before relying on this path
   DestroyCollection( Types::kValidation, deleteEvents );
   DestroyCollection( Types::kTrainingOriginal, deleteEvents );

   delete fLogger;
}
00125 
00126 //_______________________________________________________________________
00127 void TMVA::DataSet::IncrementNClassEvents( Int_t type, UInt_t classNumber ) 
00128 {
00129    if (fClassEvents.size()<(UInt_t)(type+1)) fClassEvents.resize( type+1 );
00130    if (fClassEvents.at( type ).size() < classNumber+1) fClassEvents.at( type ).resize( classNumber+1 );
00131    fClassEvents.at( type ).at( classNumber ) += 1;
00132 }
00133 
00134 //_______________________________________________________________________
00135 void TMVA::DataSet::ClearNClassEvents( Int_t type ) 
00136 {
00137    if (fClassEvents.size()<(UInt_t)(type+1)) fClassEvents.resize( type+1 );
00138    fClassEvents.at( type ).clear();
00139 }
00140 
00141 //_______________________________________________________________________
00142 Long64_t TMVA::DataSet::GetNClassEvents( Int_t type, UInt_t classNumber ) 
00143 {
00144    try {
00145       return fClassEvents.at(type).at(classNumber);
00146    } 
00147    catch (std::out_of_range excpt) {
00148       ClassInfo* ci = fdsi.GetClassInfo( classNumber );
00149       Log() << kFATAL << "No " << (type==0?"training":(type==1?"testing":"_unknown_type_")) 
00150             << " events for class " << (ci==NULL?"_no_name_known_":ci->GetName()) << " (index # "<<classNumber<<")"
00151             << " available. Check if all class names are spelled correctly and if events are" 
00152             << " passing the selection cuts." << Endl;
00153    } 
00154    catch (...) {
00155       Log() << kFATAL << "ERROR/CAUGHT : DataSet/GetNClassEvents, .. unknown error" << Endl;
00156    }
00157    return 0;
00158 }
00159 
00160 //_______________________________________________________________________
00161 void TMVA::DataSet::DestroyCollection(Types::ETreeType type, Bool_t deleteEvents )
00162 {
00163    // destroys the event collection (events + vector)
00164    UInt_t i = TreeIndex(type);
00165    if (i>=fEventCollection.size() || fEventCollection[i]==0) return;
00166    if (deleteEvents) {
00167       for (UInt_t j=0; j<fEventCollection[i]->size(); j++) delete (*fEventCollection[i])[j];
00168    }
00169    delete fEventCollection[i];
00170    fEventCollection[i]=0;
00171 }
00172 
00173 //_______________________________________________________________________
00174 TMVA::Event* TMVA::DataSet::GetEvent() const
00175 {
00176    if (fSampling.size() > UInt_t(fCurrentTreeIdx) && fSampling.at(fCurrentTreeIdx)) {
00177       Long64_t iEvt = fSamplingSelected.at(fCurrentTreeIdx).at( fCurrentEventIdx )->second;
00178       return (*(fEventCollection.at(fCurrentTreeIdx))).at(iEvt);
00179    }
00180    else {
00181       return (*(fEventCollection.at(fCurrentTreeIdx))).at(fCurrentEventIdx);
00182    }
00183 }
00184 
00185 //_______________________________________________________________________
00186 UInt_t TMVA::DataSet::GetNVariables() const 
00187 {
00188    // access the number of variables through the datasetinfo
00189    return fdsi.GetNVariables();
00190 }
00191 
00192 //_______________________________________________________________________
00193 UInt_t TMVA::DataSet::GetNTargets() const 
00194 {
00195    // access the number of targets through the datasetinfo
00196    return fdsi.GetNTargets();
00197 }
00198 
00199 //_______________________________________________________________________
00200 UInt_t TMVA::DataSet::GetNSpectators() const 
00201 {
00202    // access the number of targets through the datasetinfo
00203    return fdsi.GetNSpectators();
00204 }
00205 
00206 //_______________________________________________________________________
00207 void TMVA::DataSet::AddEvent(Event * ev, Types::ETreeType type) 
00208 {
00209    // add event to event list
00210    // after which the event is owned by the dataset
00211    fEventCollection.at(Int_t(type))->push_back(ev);
00212    if (ev->GetWeight()<0) fHasNegativeEventWeights = kTRUE;
00213    fEvtCollIt=fEventCollection.at(fCurrentTreeIdx)->begin();
00214 }
00215 
00216 //_______________________________________________________________________
00217 void TMVA::DataSet::SetEventCollection(std::vector<TMVA::Event*>* events, Types::ETreeType type) 
00218 {
00219    // Sets the event collection (by DataSetFactory)
00220    Bool_t deleteEvents = true;
00221    DestroyCollection(type,deleteEvents);
00222 
00223    const Int_t t = TreeIndex(type);
00224    ClearNClassEvents( type );
00225    fEventCollection.at(t) = events;
00226    for (std::vector<Event*>::iterator it = fEventCollection.at(t)->begin(); it < fEventCollection.at(t)->end(); it++) {
00227       IncrementNClassEvents( t, (*it)->GetClass() );
00228    }
00229    fEvtCollIt=fEventCollection.at(fCurrentTreeIdx)->begin();
00230 }
00231 
00232 //_______________________________________________________________________
00233 TMVA::Results* TMVA::DataSet::GetResults( const TString & resultsName,
00234                                           Types::ETreeType type,
00235                                           Types::EAnalysisType analysistype ) 
00236 {
00237    //    TString info(resultsName+"/");
00238    //    switch(type) {
00239    //    case Types::kTraining: info += "kTraining/";  break;
00240    //    case Types::kTesting:  info += "kTesting/";   break;
00241    //    default: break;
00242    //    }
00243    //    switch(analysistype) {
00244    //    case Types::kClassification: info += "kClassification";  break;
00245    //    case Types::kRegression:     info += "kRegression";      break;
00246    //    case Types::kNoAnalysisType: info += "kNoAnalysisType";  break;
00247    //    case Types::kMaxAnalysisType:info += "kMaxAnalysisType"; break;
00248    //    }
00249 
00250    UInt_t t = TreeIndex(type);
00251    if (t<fResults.size()) {
00252       const std::map< TString, Results* >& resultsForType = fResults[t];
00253       std::map< TString, Results* >::const_iterator it = resultsForType.find(resultsName);
00254       if (it!=resultsForType.end()) {
00255          //Log() << kINFO << " GetResults("<<info<<") returns existing result." << Endl;
00256          return it->second;
00257       }
00258    }
00259    else {
00260       fResults.resize(t+1);
00261    }
00262 
00263    // nothing found
00264 
00265    Results * newresults = 0;
00266    switch(analysistype) {
00267    case Types::kClassification:
00268       newresults = new ResultsClassification(&fdsi);
00269       break;
00270    case Types::kRegression:
00271       newresults = new ResultsRegression(&fdsi);
00272       break;
00273    case Types::kMulticlass:
00274       newresults = new ResultsMulticlass(&fdsi);
00275       break;
00276    case Types::kNoAnalysisType:
00277       newresults = new ResultsClassification(&fdsi);
00278       break;
00279    case Types::kMaxAnalysisType:
00280       //Log() << kINFO << " GetResults("<<info<<") can't create new one." << Endl;
00281       return 0;
00282       break;
00283    }
00284 
00285    newresults->SetTreeType( type );
00286    fResults[t][resultsName] = newresults;
00287 
00288    //Log() << kINFO << " GetResults("<<info<<") builds new result." << Endl;
00289    return newresults;
00290 }
00291 //_______________________________________________________________________
00292 void TMVA::DataSet::DeleteResults( const TString & resultsName,
00293                                    Types::ETreeType type,
00294                                    Types::EAnalysisType /* analysistype */ ) 
00295 {
00296    // delete the results stored for this particulary 
00297    //      Method instance  (here appareantly called resultsName instead of MethodTitle
00298    //      Tree type (Training, testing etc..)
00299    //      Analysis Type (Classification, Multiclass, Regression etc..)
00300 
00301    if (fResults.size() == 0) return;
00302 
00303    if (UInt_t(type) > fResults.size()){
00304       Log()<<kFATAL<< "you asked for an Treetype (training/testing/...)"
00305            << " whose index " << type << " does not exist " << Endl;
00306    }
00307    std::map< TString, Results* >& resultsForType = fResults[UInt_t(type)];
00308    std::map< TString, Results* >::iterator it = resultsForType.find(resultsName);
00309    if (it!=resultsForType.end()) {
00310       Log() << kDEBUG << " Delete Results previous existing result:" << resultsName 
00311             << " of type " << type << Endl;
00312       delete it->second;
00313       resultsForType.erase(it->first);
00314    }else{
00315       Log() << kINFO << "could not fine Result class of " << resultsName 
00316             << " of type " << type << " which I should have deleted" << Endl;
00317    }
00318 }
00319 //_______________________________________________________________________
00320 void TMVA::DataSet::DivideTrainingSet( UInt_t blockNum )
00321 {
00322    // divide training set
00323    Int_t tOrg = TreeIndex(Types::kTrainingOriginal),tTrn = TreeIndex(Types::kTraining);
00324    // not changing anything ??
00325    if (fBlockBelongToTraining.size() == blockNum) return;
00326    // storing the original training vector
00327    if (fBlockBelongToTraining.size() == 1) {
00328       if (fEventCollection[tOrg] == 0)
00329          fEventCollection[tOrg]=new std::vector<TMVA::Event*>(fEventCollection[tTrn]->size());
00330       fEventCollection[tOrg]->clear();
00331       for (UInt_t i=0; i<fEventCollection[tTrn]->size(); i++)
00332          fEventCollection[tOrg]->push_back((*fEventCollection[tTrn])[i]);
00333       fClassEvents[tOrg] = fClassEvents[tTrn];
00334    }
00335    //reseting the event division vector
00336    fBlockBelongToTraining.clear();
00337    for (UInt_t i=0 ; i < blockNum ; i++) fBlockBelongToTraining.push_back(kTRUE);
00338 
00339    ApplyTrainingSetDivision();
00340 }
00341 
00342 //_______________________________________________________________________
00343 void TMVA::DataSet::ApplyTrainingSetDivision()
00344 {
00345    // apply division of data set
00346    Int_t tOrg = TreeIndex(Types::kTrainingOriginal), tTrn = TreeIndex(Types::kTraining), tVld = TreeIndex(Types::kValidation);
00347    fEventCollection[tTrn]->clear();
00348    if (fEventCollection[tVld]==0)
00349       fEventCollection[tVld] = new std::vector<TMVA::Event*>(fEventCollection[tOrg]->size());
00350    fEventCollection[tVld]->clear();
00351 
00352    //creating the new events collections, notice that the events that can't be evenly divided belong to the last event
00353    for (UInt_t i=0; i<fEventCollection[tOrg]->size(); i++) {
00354       if (fBlockBelongToTraining[i % fBlockBelongToTraining.size()])
00355          fEventCollection[tTrn]->push_back((*fEventCollection[tOrg])[i]);
00356       else
00357          fEventCollection[tVld]->push_back((*fEventCollection[tOrg])[i]);
00358    }
00359 }
00360 
00361 //_______________________________________________________________________
00362 void TMVA::DataSet::MoveTrainingBlock( Int_t blockInd,Types::ETreeType dest, Bool_t applyChanges )
00363 {
00364    // move training block 
00365    if (dest == Types::kValidation)
00366       fBlockBelongToTraining[blockInd]=kFALSE;
00367    else
00368       fBlockBelongToTraining[blockInd]=kTRUE;
00369    if (applyChanges) ApplyTrainingSetDivision();
00370 }
00371 
00372 //_______________________________________________________________________
00373 Long64_t TMVA::DataSet::GetNEvtSigTest()   
00374 { 
00375    // return number of signal test events in dataset
00376    return GetNClassEvents(Types::kTesting, fdsi.GetClassInfo("Signal")->GetNumber() ); 
00377 }
00378 
00379 //_______________________________________________________________________
00380 Long64_t TMVA::DataSet::GetNEvtBkgdTest()  
00381 { 
00382    // return number of background test events in dataset
00383    return GetNClassEvents(Types::kTesting, fdsi.GetClassInfo("Background")->GetNumber() ); 
00384 }
00385 
00386 //_______________________________________________________________________
00387 Long64_t TMVA::DataSet::GetNEvtSigTrain()  
00388 { 
00389    // return number of signal training events in dataset
00390    return GetNClassEvents(Types::kTraining, fdsi.GetClassInfo("Signal")->GetNumber() ); 
00391 }
00392 
00393 //_______________________________________________________________________
00394 Long64_t TMVA::DataSet::GetNEvtBkgdTrain() 
00395 { 
00396    // return number of background training events in dataset
00397    return GetNClassEvents(Types::kTraining, fdsi.GetClassInfo("Background")->GetNumber() ); 
00398 }
00399 
00400 //_______________________________________________________________________
00401 void TMVA::DataSet::InitSampling( Float_t fraction, Float_t weight, UInt_t seed  )
00402 {
00403    // initialize random or importance sampling
00404 
00405    // add a random generator if not yet present
00406    if (fSamplingRandom == 0 ) fSamplingRandom = new TRandom3( seed );
00407 
00408    // first, clear the lists
00409    std::vector< std::pair< Float_t, Long64_t >* > evtList;
00410    std::vector< std::pair< Float_t, Long64_t >* >::iterator it;
00411 
00412    Int_t treeIdx = TreeIndex( GetCurrentType() );
00413 
00414    if (fSamplingEventList.size() < UInt_t(treeIdx+1) ) fSamplingEventList.resize(treeIdx+1);
00415    if (fSamplingSelected.size() < UInt_t(treeIdx+1) ) fSamplingSelected.resize(treeIdx+1);
00416    for (it = fSamplingEventList.at(treeIdx).begin(); it != fSamplingEventList.at(treeIdx).end(); it++ ) delete (*it);
00417    fSamplingEventList.at(treeIdx).clear();
00418    fSamplingSelected.at(treeIdx).clear();
00419 
00420    if (fSampling.size() < UInt_t(treeIdx+1) )         fSampling.resize(treeIdx+1);
00421    if (fSamplingNEvents.size() < UInt_t(treeIdx+1) ) fSamplingNEvents.resize(treeIdx+1);
00422    if (fSamplingWeight.size() < UInt_t(treeIdx+1) )   fSamplingWeight.resize(treeIdx+1);
00423       
00424    if (fraction > 0.999999 || fraction < 0.0000001) {
00425       fSampling.at( treeIdx ) = false;
00426       fSamplingNEvents.at( treeIdx ) = 0;
00427       fSamplingWeight.at( treeIdx ) = 1.0;
00428       return;
00429    }
00430 
00431    // for the initialization, the sampling has to be turned off, afterwards we will turn it on
00432    fSampling.at( treeIdx )  = false;
00433 
00434    fSamplingNEvents.at( treeIdx ) = Int_t(fraction*GetNEvents());
00435    fSamplingWeight.at( treeIdx ) = weight;
00436 
00437    Long64_t nEvts = GetNEvents();
00438    fSamplingEventList.at( treeIdx ).reserve( nEvts );
00439    fSamplingSelected.at( treeIdx ).reserve( fSamplingNEvents.at(treeIdx) );
00440    for (Long64_t ievt=0; ievt<nEvts; ievt++) {
00441       std::pair<Float_t,Long64_t> *p = new std::pair<Float_t,Long64_t>(std::make_pair<Float_t,Long64_t>(1.0,ievt));
00442       fSamplingEventList.at( treeIdx ).push_back( p );
00443    }
00444 
00445    // now turn sampling on
00446    fSampling.at( treeIdx ) = true;
00447 }
00448 
00449 
//_______________________________________________________________________
void TMVA::DataSet::CreateSampling() const
{
   // create an event sampling (random or importance sampling) for the
   // current tree type: draw fSamplingNEvents events from the full event
   // list, with each event's chance proportional to its current sampling
   // weight, and store the chosen (weight, index) pairs in
   // fSamplingSelected

   Int_t treeIdx = TreeIndex( GetCurrentType() );

   if (!fSampling.at(treeIdx) ) return;

   if (fSamplingRandom == 0 )
      Log() << kFATAL
            << "no random generator present for creating a random/importance sampling (initialized?)" << Endl;

   // delete the previous selection
   // (only the selection list is cleared; the pairs stay owned by
   // fSamplingEventList)
   fSamplingSelected.at(treeIdx).clear();

   // create a temporary event-list
   std::vector< std::pair< Float_t, Long64_t >* > evtList;
   std::vector< std::pair< Float_t, Long64_t >* >::iterator evtListIt;

   // some variables
   Float_t sumWeights = 0;

   // make a copy of the event-list (copies the pointers, not the pairs)
   evtList.assign( fSamplingEventList.at(treeIdx).begin(), fSamplingEventList.at(treeIdx).end() );

   // sum up all the weights (internal weights for importance sampling)
   for (evtListIt = evtList.begin(); evtListIt != evtList.end(); evtListIt++) {
      sumWeights += (*evtListIt)->first;
   }
   evtListIt = evtList.begin();

   // random numbers: fSamplingNEvents draw positions in [0, sumWeights)
   std::vector< Float_t > rnds;
   rnds.reserve(fSamplingNEvents.at(treeIdx));

   Float_t pos = 0;
   for (Int_t i = 0; i < fSamplingNEvents.at(treeIdx); i++) {
      pos = fSamplingRandom->Rndm()*sumWeights;
      rnds.push_back( pos );
   }
   
   // sort the random numbers so a single sweep over the event list can
   // serve all of them in order
   std::sort(rnds.begin(),rnds.end());
   
   // select the events according to the random numbers: walk the list
   // accumulating weights; whenever the running sum passes the next draw
   // position, the current event is selected and removed from the pool
   // (so it cannot be drawn twice)
   std::vector< Float_t >::iterator rndsIt = rnds.begin();
   Float_t runningSum = 0.000000001;   // tiny epsilon so a first draw of exactly 0 still selects an event
   for (evtListIt = evtList.begin(); evtListIt != evtList.end();) {
      runningSum += (*evtListIt)->first;
      if (runningSum >= (*rndsIt)) {
         fSamplingSelected.at(treeIdx).push_back( (*evtListIt) );
         evtListIt = evtList.erase( evtListIt );

         rndsIt++;
         if (rndsIt == rnds.end() ) break;
      }else{
         evtListIt++;
      }
   }
}
00511 
00512 //_______________________________________________________________________
00513 void TMVA::DataSet::EventResult( Bool_t successful, Long64_t evtNumber )
00514 {
00515    // increase the importance sampling weight of the event 
00516    // when not successful and decrease it when successful
00517 
00518 
00519    if (!fSampling.at(fCurrentTreeIdx)) return;
00520    if (fSamplingWeight.at(fCurrentTreeIdx) > 0.99999999999) return;
00521 
00522    Long64_t start = 0;
00523    Long64_t stop  = fSamplingEventList.at(fCurrentTreeIdx).size() -1;
00524    if (evtNumber >= 0) {
00525       start = evtNumber; 
00526       stop  = evtNumber;
00527    }
00528    for ( Long64_t iEvt = start; iEvt <= stop; iEvt++ ){
00529       if (Long64_t(fSamplingEventList.at(fCurrentTreeIdx).size()) < iEvt) {
00530          Log() << kWARNING << "event number (" << iEvt 
00531                << ") larger than number of sampled events (" 
00532                << fSamplingEventList.at(fCurrentTreeIdx).size() << " of tree " << fCurrentTreeIdx << ")" << Endl;
00533          return;
00534       }
00535       Float_t weight = fSamplingEventList.at(fCurrentTreeIdx).at( iEvt )->first;
00536       if (!successful) {
00537          //      weight /= (fSamplingWeight.at(fCurrentTreeIdx)/fSamplingEventList.at(fCurrentTreeIdx).size());
00538          weight /= fSamplingWeight.at(fCurrentTreeIdx);
00539          if (weight > 1.0 ) weight = 1.0;
00540       }else{
00541          //      weight *= (fSamplingWeight.at(fCurrentTreeIdx)/fSamplingEventList.at(fCurrentTreeIdx).size());
00542          weight *= fSamplingWeight.at(fCurrentTreeIdx);
00543       }
00544       fSamplingEventList.at(fCurrentTreeIdx).at( iEvt )->first = weight;
00545    }
00546 }
00547 
00548 
//_______________________________________________________________________
TTree* TMVA::DataSet::GetTree( Types::ETreeType type ) 
{ 
   // create the test/trainings tree with all the variables, the weights, the classes, the targets, the spectators, the MVA outputs
   
   Log() << kDEBUG << "GetTree(" << ( type==Types::kTraining ? "training" : "testing" ) << ")" << Endl;

   // the dataset does not hold the tree, this function returns a new tree everytime it is called

   if (type!=Types::kTraining && type!=Types::kTesting) return 0;

   Types::ETreeType savedType = GetCurrentType();

   SetCurrentType(type);
   const UInt_t t = TreeIndex(type);
   if (fResults.size() <= t) {
      Log() << kWARNING << "No results for treetype " << ( type==Types::kTraining ? "training" : "testing" ) 
            << " found. Size=" << fResults.size() << Endl;
   }

   // the tree is named after the tree type it was built from
   TString treeName( (type == Types::kTraining ? "TrainTree" : "TestTree" ) );
   TTree *tree = new TTree(treeName,treeName);

   // flat buffers the branches below point into; filled per event
   Float_t *varVals = new Float_t[fdsi.GetNVariables()];
   Float_t *tgtVals = new Float_t[fdsi.GetNTargets()];
   Float_t *visVals = new Float_t[fdsi.GetNSpectators()];

   UInt_t cls;
   Float_t weight;
   //   TObjString *className = new TObjString();
   // NOTE(review): fixed-size buffer -- a class name of 40+ characters
   // (including terminator) would overflow it in the copy loop below;
   // confirm upstream limits on class-name length
   char *className = new char[40];


   //Float_t metVals[fResults.at(t).size()][Int_t(fdsi.GetNTargets()+1)];
   // replace by:  [Joerg]
   // one row per registered Results object; sized for the larger of the
   // regression-target and multiclass use cases
   Float_t **metVals = new Float_t*[fResults.at(t).size()];
   for(UInt_t i=0; i<fResults.at(t).size(); i++ )
      metVals[i] = new Float_t[fdsi.GetNTargets()+fdsi.GetNClasses()];

   // create branches for event-variables
   tree->Branch( "classID", &cls, "classID/I" ); 
   tree->Branch( "className",(void*)className, "className/C" ); 


   // create all branches for the variables
   Int_t n = 0;
   for (std::vector<VariableInfo>::const_iterator itVars = fdsi.GetVariableInfos().begin(); 
        itVars != fdsi.GetVariableInfos().end(); itVars++) {

      // has to be changed to take care of types different than float: TODO
      tree->Branch( (*itVars).GetInternalName(), &varVals[n], (*itVars).GetInternalName()+TString("/F") ); 
      n++;
   }
   // create the branches for the targets
   n = 0;
   for (std::vector<VariableInfo>::const_iterator itTgts = fdsi.GetTargetInfos().begin(); 
        itTgts != fdsi.GetTargetInfos().end(); itTgts++) {
      // has to be changed to take care of types different than float: TODO
      tree->Branch( (*itTgts).GetInternalName(), &tgtVals[n], (*itTgts).GetInternalName()+TString("/F") ); 
      n++;
   }
   // create the branches for the spectator variables
   n = 0;
   for (std::vector<VariableInfo>::const_iterator itVis = fdsi.GetSpectatorInfos().begin(); 
        itVis != fdsi.GetSpectatorInfos().end(); itVis++) {
      // has to be changed to take care of types different than float: TODO
      tree->Branch( (*itVis).GetInternalName(), &visVals[n], (*itVis).GetInternalName()+TString("/F") ); 
      n++;
   }

   tree->Branch( "weight", &weight, "weight/F" );

   // create all the branches for the results: one branch per registered
   // method, whose leaf layout depends on the method's analysis type
   n = 0;
   for (std::map< TString, Results* >::iterator itMethod = fResults.at(t).begin(); 
        itMethod != fResults.at(t).end(); itMethod++) {


      Log() << kDEBUG << "analysis type: " << (itMethod->second->GetAnalysisType()==Types::kRegression ? "Regression" :
                                               (itMethod->second->GetAnalysisType()==Types::kMulticlass ? "Multiclass" : "Classification" )) << Endl;
      
      if (itMethod->second->GetAnalysisType() == Types::kClassification) {
         // classification: a single float (the MVA output)
         tree->Branch( itMethod->first, &(metVals[n][0]), itMethod->first + "/F" );
      } else if (itMethod->second->GetAnalysisType() == Types::kMulticlass) {
         // multiclass classification: one float leaf per class, named
         // after the class
         TString leafList("");
         for (UInt_t iCls = 0; iCls < fdsi.GetNClasses(); iCls++) {
            if (iCls > 0) leafList.Append( ":" );
            leafList.Append( fdsi.GetClassInfo( iCls )->GetName() );
            leafList.Append( "/F" );
         }
         Log() << kDEBUG << "itMethod->first " << itMethod->first <<  "    LEAFLIST: " 
               << leafList << "    itMethod->second " << itMethod->second <<  Endl;
         tree->Branch( itMethod->first, (metVals[n]), leafList );
      } else if (itMethod->second->GetAnalysisType() == Types::kRegression) {
         // regression: one float leaf per target, named after the target
         TString leafList("");
         for (UInt_t iTgt = 0; iTgt < fdsi.GetNTargets(); iTgt++) {
            if (iTgt > 0) leafList.Append( ":" );
            leafList.Append( fdsi.GetTargetInfo( iTgt ).GetInternalName() );
            //            leafList.Append( fdsi.GetTargetInfo( iTgt ).GetLabel() );
            leafList.Append( "/F" );
         }
         Log() << kDEBUG << "itMethod->first " << itMethod->first <<  "    LEAFLIST: " 
               << leafList << "    itMethod->second " << itMethod->second <<  Endl;
         tree->Branch( itMethod->first, (metVals[n]), leafList );
      } else {
         Log() << kWARNING << "Unknown analysis type for result found when writing TestTree." << Endl;
      }
      n++;

   }

   // loop through all the events, filling the buffers and the tree
   for (Long64_t iEvt = 0; iEvt < GetNEvents( type ); iEvt++) {
      // write the event-variables
      const Event* ev = GetEvent( iEvt );

      // write the classnumber and the classname
      cls = ev->GetClass();
      weight = ev->GetWeight();
      // copy the class name into the char buffer, keeping it
      // null-terminated after every character written
      TString tmp = fdsi.GetClassInfo( cls )->GetName();
      for (Int_t itmp = 0; itmp < tmp.Sizeof(); itmp++) {
         className[itmp] = tmp(itmp);
         className[itmp+1] = 0;
      }

      // write the variables, targets and spectator variables
      for (UInt_t ivar = 0; ivar < ev->GetNVariables();   ivar++) varVals[ivar] = ev->GetValue( ivar );
      for (UInt_t itgt = 0; itgt < ev->GetNTargets();     itgt++) tgtVals[itgt] = ev->GetTarget( itgt );
      for (UInt_t ivis = 0; ivis < ev->GetNSpectators();  ivis++) visVals[ivis] = ev->GetSpectator( ivis );


      // loop through all the results and write the branches
      // (iteration order matches the branch-creation loop above, so the
      // row index n lines up with the right branch)
      n=0;
      for (std::map<TString, Results*>::iterator itMethod = fResults.at(t).begin();
           itMethod != fResults.at(t).end(); itMethod++) {

         Results* results = itMethod->second;
         const std::vector< Float_t >& vals = results->operator[](iEvt);

         if (itMethod->second->GetAnalysisType() == Types::kClassification) {
            // classification
            metVals[n][0] = vals[0];
         }
         else if (itMethod->second->GetAnalysisType() == Types::kMulticlass) {
            // multiclass classification
            for (UInt_t nCls = 0, nClsEnd=fdsi.GetNClasses(); nCls < nClsEnd; nCls++) {
               Float_t val = vals.at(nCls);
               metVals[n][nCls] = val;
            }
         }
         else if (itMethod->second->GetAnalysisType() == Types::kRegression) {
            // regression
            for (UInt_t nTgts = 0; nTgts < fdsi.GetNTargets(); nTgts++) {
               Float_t val = vals.at(nTgts);
               metVals[n][nTgts] = val;
            }
         }
         n++;
      }
      // fill the variables into the tree
      tree->Fill();
   }

   Log() << kINFO << "Created tree '" << tree->GetName() << "' with " << tree->GetEntries() << " events" << Endl;

   SetCurrentType(savedType);

   // release the per-call buffers (the returned tree is owned by the caller)
   delete[] varVals;
   delete[] tgtVals;
   delete[] visVals;

   for(UInt_t i=0; i<fResults.at(t).size(); i++ )
      delete[] metVals[i];
   delete[] metVals;

   delete[] className;

   return tree;
}
00732 

Generated on Tue Jul 5 15:16:44 2011 for ROOT_528-00b_version by  doxygen 1.5.1