DataSetFactory.cxx

// @(#)root/tmva $Id: DataSetFactory.cxx 36966 2010-11-26 09:50:13Z evt $
// Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Eckhard von Toerne, Helge Voss

/**********************************************************************************
 * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
 * Package: TMVA                                                                  *
 * Class  : DataSetFactory                                                        *
 * Web    : http://tmva.sourceforge.net                                           *
 *                                                                                *
 * Description:                                                                   *
 *      Implementation (see header for description)                               *
 *                                                                                *
 * Authors (alphabetical):                                                        *
 *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
 *      Peter Speckmayer <Peter.Speckmayer@cern.ch> - CERN, Switzerland           *
 *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
 *      Eckhard von Toerne <evt@physik.uni-bonn.de>  - U. of Bonn, Germany        *
 *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
 *                                                                                *
 * Copyright (c) 2009:                                                            *
 *      CERN, Switzerland                                                         *
 *      MPI-K Heidelberg, Germany                                                 *
 *      U. of Bonn, Germany                                                       *
 * Redistribution and use in source and binary forms, with or without             *
 * modification, are permitted according to the terms listed in LICENSE           *
 * (http://tmva.sourceforge.net/LICENSE)                                          *
 **********************************************************************************/

#include <assert.h>

#include <map>
#include <vector>
#include <iomanip>
#include <iostream>

#include <algorithm>
#include <functional>
#include <numeric>

#include "TMVA/DataSetFactory.h"

#include "TEventList.h"
#include "TFile.h"
#include "TH1.h"
#include "TH2.h"
#include "TProfile.h"
#include "TRandom3.h"
#include "TMatrixF.h"
#include "TVectorF.h"
#include "TMath.h"
#include "TROOT.h"

#ifndef ROOT_TMVA_MsgLogger
#include "TMVA/MsgLogger.h"
#endif
#ifndef ROOT_TMVA_Configurable
#include "TMVA/Configurable.h"
#endif
#ifndef ROOT_TMVA_VariableIdentityTransform
#include "TMVA/VariableIdentityTransform.h"
#endif
#ifndef ROOT_TMVA_VariableDecorrTransform
#include "TMVA/VariableDecorrTransform.h"
#endif
#ifndef ROOT_TMVA_VariablePCATransform
#include "TMVA/VariablePCATransform.h"
#endif
#ifndef ROOT_TMVA_DataSet
#include "TMVA/DataSet.h"
#endif
#ifndef ROOT_TMVA_DataSetInfo
#include "TMVA/DataSetInfo.h"
#endif
#ifndef ROOT_TMVA_DataInputHandler
#include "TMVA/DataInputHandler.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif

using namespace std;

TMVA::DataSetFactory* TMVA::DataSetFactory::fgInstance = 0;

namespace TMVA {
   // calculate the greatest common divisor (Euclidean algorithm);
   // note: this function is not valid for negative input values!
   Int_t LargestCommonDivider(Int_t a, Int_t b) 
   {
      if (a<b) {Int_t tmp = a; a=b; b=tmp; } // achieve a>=b
      if (b==0) return a;
      Int_t fullFits = a/b;
      return LargestCommonDivider(b,a-b*fullFits);
   }
}
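
// Illustrative trace (added for clarity, not part of the original source):
// the recursion above is Euclid's algorithm with the remainder spelled out, e.g.
//
//    LargestCommonDivider(12,18):  a=18, b=12  ->  recurse(12, 18-1*12 = 6)
//                                  a=12, b=6   ->  recurse(6,  12-2*6  = 0)
//                                  b==0        ->  return 6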

//_______________________________________________________________________
TMVA::DataSetFactory::DataSetFactory() :
   fVerbose(kFALSE),
   fVerboseLevel(TString("Info")),
   fCurrentTree(0),
   fCurrentEvtIdx(0),
   fInputFormulas(0),
   fLogger( new MsgLogger("DataSetFactory", kINFO) )
{
   // constructor
}

//_______________________________________________________________________
TMVA::DataSetFactory::~DataSetFactory() 
{
   // destructor
   std::vector<TTreeFormula*>::const_iterator formIt;

   for (formIt = fInputFormulas.begin()    ; formIt!=fInputFormulas.end()    ; formIt++) if (*formIt) delete *formIt;
   for (formIt = fTargetFormulas.begin()   ; formIt!=fTargetFormulas.end()   ; formIt++) if (*formIt) delete *formIt;
   for (formIt = fCutFormulas.begin()      ; formIt!=fCutFormulas.end()      ; formIt++) if (*formIt) delete *formIt;
   for (formIt = fWeightFormula.begin()    ; formIt!=fWeightFormula.end()    ; formIt++) if (*formIt) delete *formIt;
   for (formIt = fSpectatorFormulas.begin(); formIt!=fSpectatorFormulas.end(); formIt++) if (*formIt) delete *formIt;

   delete fLogger;
}

//_______________________________________________________________________
TMVA::DataSet* TMVA::DataSetFactory::CreateDataSet( TMVA::DataSetInfo& dsi, TMVA::DataInputHandler& dataInput ) 
{
   // steering the creation of a new dataset

   // build the first dataset from the data input
   DataSet * ds = BuildInitialDataSet( dsi, dataInput );

   if (ds->GetNEvents() > 1) {
      CalcMinMax(ds,dsi);

      // from the final dataset, build the correlation matrix
      for (UInt_t cl = 0; cl< dsi.GetNClasses(); cl++) {
         const TString className = dsi.GetClassInfo(cl)->GetName();
         dsi.SetCorrelationMatrix( className, CalcCorrelationMatrix( ds, cl ) );
         dsi.PrintCorrelationMatrix( className );
      }
      Log() << kINFO << " " << Endl;
   }
   return ds;
}

//_______________________________________________________________________
TMVA::DataSet* TMVA::DataSetFactory::BuildDynamicDataSet( TMVA::DataSetInfo& dsi ) 
{
   Log() << kDEBUG << "Build DataSet consisting of one Event with dynamically changing variables" << Endl;
   DataSet* ds = new DataSet(dsi);

   // create a DataSet with one Event which uses dynamic variables (pointers to variables)
   if(dsi.GetNClasses()==0){
      dsi.AddClass( "data" );
      dsi.GetClassInfo( "data" )->SetNumber(0);
   }

   std::vector<Float_t*>* evdyn = new std::vector<Float_t*>(0);

   std::vector<VariableInfo>& varinfos = dsi.GetVariableInfos();
   std::vector<VariableInfo>::iterator it = varinfos.begin();
   for (;it!=varinfos.end();it++) evdyn->push_back( (Float_t*)(*it).GetExternalLink() );

   std::vector<VariableInfo>& spectatorinfos = dsi.GetSpectatorInfos();
   it = spectatorinfos.begin();
   for (;it!=spectatorinfos.end();it++) evdyn->push_back( (Float_t*)(*it).GetExternalLink() );

   TMVA::Event * ev = new Event((const std::vector<Float_t*>*&)evdyn, varinfos.size());
   std::vector<Event*>* newEventVector = new std::vector<Event*>;
   newEventVector->push_back(ev);

   ds->SetEventCollection(newEventVector, Types::kTraining);
   ds->SetCurrentType( Types::kTraining );
   ds->SetCurrentEvent( 0 );

   return ds;
}
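
// Usage sketch (illustrative only; the variable names and method tag are
// assumptions of this example): such a dynamic DataSet backs the application
// phase, where variables are bound to external Float_t pointers, e.g. via the
// TMVA::Reader:
//
//    Float_t var1, var2;
//    TMVA::Reader reader;
//    reader.AddVariable( "var1", &var1 );   // external links end up in VariableInfo
//    reader.AddVariable( "var2", &var2 );
//    // ... book a method; then, for each candidate event:
//    var1 = 1.5; var2 = -0.2;               // update the pointed-to values
//    Double_t mva = reader.EvaluateMVA( "BDT method" );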


//_______________________________________________________________________
TMVA::DataSet* TMVA::DataSetFactory::BuildInitialDataSet( DataSetInfo& dsi, DataInputHandler& dataInput ) 
{
   // if there are no entries, then create a DataSet with one Event which uses dynamic variables (pointers to variables)
   if (dataInput.GetEntries()==0) return BuildDynamicDataSet( dsi );
   // ------------------------------------------------------------------------------------

   // register the classes in the datasetinfo-object
   // information comes from the trees in the dataInputHandler-object
   std::vector< TString >* classList = dataInput.GetClassList();
   for (std::vector<TString>::iterator it = classList->begin(); it< classList->end(); it++) {
      dsi.AddClass( (*it) );
   }
   delete classList;

   TString normMode;
   TString splitMode;
   TString mixMode;
   UInt_t splitSeed;

   // ======= build event-vector tentative new ordering =================================

   TMVA::EventVectorOfClassesOfTreeType tmpEventVector;
   TMVA::NumberPerClassOfTreeType       nTrainTestEvents;

   InitOptions     ( dsi, nTrainTestEvents, normMode, splitSeed, splitMode , mixMode );
   BuildEventVector( dsi, dataInput, tmpEventVector );

   DataSet* ds = MixEvents( dsi, tmpEventVector, nTrainTestEvents, splitMode, mixMode, normMode, splitSeed);

   const Bool_t showCollectedOutput = kFALSE;
   if (showCollectedOutput) {
      Int_t maxL = dsi.GetClassNameMaxLength();
      Log() << kINFO << "Collected:" << Endl;
      for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
         Log() << kINFO << "    " 
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName() 
               << " training entries: " << ds->GetNClassEvents( 0, cl ) << Endl;
         Log() << kINFO << "    " 
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName() 
               << " testing  entries: " << ds->GetNClassEvents( 1, cl ) << Endl;
      }
      Log() << kINFO << " " << Endl;
   }

   return ds;
}

//_______________________________________________________________________
Bool_t TMVA::DataSetFactory::CheckTTreeFormula( TTreeFormula* ttf, const TString& expression, Bool_t& hasDollar )
{ 
   // checks a TTreeFormula for problems
   Bool_t worked = kTRUE;

   if( ttf->GetNdim() <= 0 )
      Log() << kFATAL << "Expression " << expression.Data() << " could not be resolved to a valid formula. " << Endl;
   //    if( ttf->GetNcodes() == 0 ){
   //       Log() << kWARNING << "Expression: " << expression.Data() << " does not appear to depend on any TTree variable --> please check spelling" << Endl;
   //       worked = kFALSE;
   //    }
   if( ttf->GetNdata() == 0 ){
      Log() << kWARNING << "Expression: " << expression.Data() 
            << " does not provide data for this event. "
            << "This event is not taken into account. --> please check whether you use as a variable "
            << "an entry of an array that is not filled for some events "
            << "(e.g. arr[4] when arr has only 3 elements)." << Endl;
      Log() << kWARNING << "If you want to take the event into account you can use something like "
            << "\"Alt$(arr[4],0)\", where 0 is taken as an alternative whenever arr[4] "
            << "does not exist." << Endl;
      worked = kFALSE;
   }
   if( expression.Contains("$") ) hasDollar = kTRUE;
   return worked;
}
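
// Illustrative example (added; the names "factory" and "arr" are assumptions
// of this example): to keep events where an array entry may be absent, declare
// the variable with ROOT's Alt$ syntax so that a default value is substituted:
//
//    factory->AddVariable( "Alt$(arr[4],0)", 'F' );  // 0 whenever arr[4] does not exist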

//_______________________________________________________________________
void TMVA::DataSetFactory::ChangeToNewTree( TreeInfo& tinfo, const DataSetInfo & dsi )
{ 
   // While the data gets copied into the local training and testing
   // trees, the input tree can change (for instance when changing from
   // the signal to the background tree, or when using TChains as input).
   // The TTreeFormulas that hold the input expressions therefore need to
   // be reassociated with the new tree, which is done here.

   TTree *tr = tinfo.GetTree()->GetTree();

   tr->SetBranchStatus("*",1);

   Bool_t hasDollar = kFALSE;

   // 1) the input variable formulas
   Log() << kDEBUG << "transform input variables" << Endl;
   std::vector<TTreeFormula*>::const_iterator formIt, formItEnd;
   for (formIt = fInputFormulas.begin(), formItEnd=fInputFormulas.end(); formIt!=formItEnd; formIt++) if (*formIt) delete *formIt;
   fInputFormulas.clear();
   TTreeFormula* ttf = 0;

   for (UInt_t i=0; i<dsi.GetNVariables(); i++) {
      ttf = new TTreeFormula( Form( "Formula%s", dsi.GetVariableInfo(i).GetInternalName().Data() ),
                              dsi.GetVariableInfo(i).GetExpression().Data(), tr );
      CheckTTreeFormula( ttf, dsi.GetVariableInfo(i).GetExpression(), hasDollar );
      fInputFormulas.push_back( ttf );
   }

   //
   // targets
   //
   Log() << kDEBUG << "transform regression targets" << Endl;
   for (formIt = fTargetFormulas.begin(), formItEnd = fTargetFormulas.end(); formIt!=formItEnd; formIt++) if (*formIt) delete *formIt;
   fTargetFormulas.clear();
   for (UInt_t i=0; i<dsi.GetNTargets(); i++) {
      ttf = new TTreeFormula( Form( "Formula%s", dsi.GetTargetInfo(i).GetInternalName().Data() ),
                              dsi.GetTargetInfo(i).GetExpression().Data(), tr );
      CheckTTreeFormula( ttf, dsi.GetTargetInfo(i).GetExpression(), hasDollar );
      fTargetFormulas.push_back( ttf );
   }

   //
   // spectators
   //
   Log() << kDEBUG << "transform spectator variables" << Endl;
   for (formIt = fSpectatorFormulas.begin(), formItEnd = fSpectatorFormulas.end(); formIt!=formItEnd; formIt++) if (*formIt) delete *formIt;
   fSpectatorFormulas.clear();
   for (UInt_t i=0; i<dsi.GetNSpectators(); i++) {
      ttf = new TTreeFormula( Form( "Formula%s", dsi.GetSpectatorInfo(i).GetInternalName().Data() ),
                              dsi.GetSpectatorInfo(i).GetExpression().Data(), tr );
      CheckTTreeFormula( ttf, dsi.GetSpectatorInfo(i).GetExpression(), hasDollar );
      fSpectatorFormulas.push_back( ttf );
   }

   //
   // the cuts (one per class, if non-existent: formula pointer = 0)
   //
   Log() << kDEBUG << "transform cuts" << Endl;
   for (formIt = fCutFormulas.begin(), formItEnd = fCutFormulas.end(); formIt!=formItEnd; formIt++) if (*formIt) delete *formIt;
   fCutFormulas.clear();
   for (UInt_t clIdx=0; clIdx<dsi.GetNClasses(); clIdx++) {
      const TCut& tmpCut = dsi.GetClassInfo(clIdx)->GetCut();
      const TString tmpCutExp(tmpCut.GetTitle());
      ttf = 0;
      if (tmpCutExp!="") {
         ttf = new TTreeFormula( Form("CutClass%i",clIdx), tmpCutExp, tr );
         Bool_t worked = CheckTTreeFormula( ttf, tmpCutExp, hasDollar );
         if( !worked ){
            Log() << kWARNING << "Please check class \"" << dsi.GetClassInfo(clIdx)->GetName()
                  << "\" cut \"" << dsi.GetClassInfo(clIdx)->GetCut() << "\"" << Endl;
         }
      }
      fCutFormulas.push_back( ttf );
   }

   //
   // the weights (one per class, if non-existent: formula pointer = 0)
   //
   Log() << kDEBUG << "transform weights" << Endl;
   for (formIt = fWeightFormula.begin(), formItEnd = fWeightFormula.end(); formIt!=formItEnd; formIt++) if (*formIt) delete *formIt;
   fWeightFormula.clear();
   for (UInt_t clIdx=0; clIdx<dsi.GetNClasses(); clIdx++) {
      const TString tmpWeight = dsi.GetClassInfo(clIdx)->GetWeight();

      if (dsi.GetClassInfo(clIdx)->GetName() != tinfo.GetClassName() ) { // if the tree is of another class
         fWeightFormula.push_back( 0 );
         continue; 
      }

      ttf = 0;
      if (tmpWeight!="") {
         ttf = new TTreeFormula( "FormulaWeight", tmpWeight, tr );
         Bool_t worked = CheckTTreeFormula( ttf, tmpWeight, hasDollar );
         if( !worked ){
            Log() << kWARNING << "Please check class \"" << dsi.GetClassInfo(clIdx)->GetName()
                  << "\" weight \"" << dsi.GetClassInfo(clIdx)->GetWeight() << "\"" << Endl;
         }
      }
      else {
         ttf = 0;
      }
      fWeightFormula.push_back( ttf );
   }
   Log() << kDEBUG << "enable branches" << Endl;
   // now enable only branches that are needed in any input formula, target, cut, weight

   if (!hasDollar) {
      tr->SetBranchStatus("*",0);
      Log() << kDEBUG << "enable branches: input variables" << Endl;
      // input vars
      for (formIt = fInputFormulas.begin(); formIt!=fInputFormulas.end(); formIt++) {
         ttf = *formIt;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++) {
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
         }
      }
      // targets
      Log() << kDEBUG << "enable branches: targets" << Endl;
      for (formIt = fTargetFormulas.begin(); formIt!=fTargetFormulas.end(); formIt++) {
         ttf = *formIt;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }
      // spectators
      Log() << kDEBUG << "enable branches: spectators" << Endl;
      for (formIt = fSpectatorFormulas.begin(); formIt!=fSpectatorFormulas.end(); formIt++) {
         ttf = *formIt;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }
      // cuts
      Log() << kDEBUG << "enable branches: cuts" << Endl;
      for (formIt = fCutFormulas.begin(); formIt!=fCutFormulas.end(); formIt++) {
         ttf = *formIt;
         if (!ttf) continue;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }
      // weights
      Log() << kDEBUG << "enable branches: weights" << Endl;
      for (formIt = fWeightFormula.begin(); formIt!=fWeightFormula.end(); formIt++) {
         ttf = *formIt;
         if (!ttf) continue;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }
   }
   Log() << kDEBUG << "tree initialized" << Endl;
   return;
}

//_______________________________________________________________________
void TMVA::DataSetFactory::CalcMinMax( DataSet* ds, TMVA::DataSetInfo& dsi )
{
   // determine the minimum and maximum values of variables, targets and spectators
   const UInt_t nvar  = ds->GetNVariables();
   const UInt_t ntgts = ds->GetNTargets();
   const UInt_t nvis  = ds->GetNSpectators();

   Float_t *min = new Float_t[nvar];
   Float_t *max = new Float_t[nvar];
   Float_t *tgmin = new Float_t[ntgts];
   Float_t *tgmax = new Float_t[ntgts];
   Float_t *vmin  = new Float_t[nvis];
   Float_t *vmax  = new Float_t[nvis];

   for (UInt_t ivar=0; ivar<nvar ; ivar++) {   min[ivar] = FLT_MAX;   max[ivar] = -FLT_MAX; }
   for (UInt_t ivar=0; ivar<ntgts; ivar++) { tgmin[ivar] = FLT_MAX; tgmax[ivar] = -FLT_MAX; }
   for (UInt_t ivar=0; ivar<nvis;  ivar++) {  vmin[ivar] = FLT_MAX;  vmax[ivar] = -FLT_MAX; }

   // perform event loop

   for (Int_t i=0; i<ds->GetNEvents(); i++) {
      Event * ev = ds->GetEvent(i);
      for (UInt_t ivar=0; ivar<nvar; ivar++) {
         Double_t v = ev->GetValue(ivar);
         if (v<min[ivar]) min[ivar] = v;
         if (v>max[ivar]) max[ivar] = v;
      }
      for (UInt_t itgt=0; itgt<ntgts; itgt++) {
         Double_t v = ev->GetTarget(itgt);
         if (v<tgmin[itgt]) tgmin[itgt] = v;
         if (v>tgmax[itgt]) tgmax[itgt] = v;
      }
      for (UInt_t ivis=0; ivis<nvis; ivis++) {
         Double_t v = ev->GetSpectator(ivis);
         if (v<vmin[ivis]) vmin[ivis] = v;
         if (v>vmax[ivis]) vmax[ivis] = v;
      }
   }

   for (UInt_t ivar=0; ivar<nvar; ivar++) {
      dsi.GetVariableInfo(ivar).SetMin(min[ivar]);
      dsi.GetVariableInfo(ivar).SetMax(max[ivar]);
      if( TMath::Abs(max[ivar]-min[ivar]) <= FLT_MIN )
         Log() << kFATAL << "Variable " << dsi.GetVariableInfo(ivar).GetExpression().Data() << " is constant. Please remove the variable." << Endl;
   }
   for (UInt_t ivar=0; ivar<ntgts; ivar++) {
      dsi.GetTargetInfo(ivar).SetMin(tgmin[ivar]);
      dsi.GetTargetInfo(ivar).SetMax(tgmax[ivar]);
      if( TMath::Abs(tgmax[ivar]-tgmin[ivar]) <= FLT_MIN )
         Log() << kFATAL << "Target " << dsi.GetTargetInfo(ivar).GetExpression().Data() << " is constant. Please remove the variable." << Endl;
   }
   for (UInt_t ivar=0; ivar<nvis; ivar++) {
      dsi.GetSpectatorInfo(ivar).SetMin(vmin[ivar]);
      dsi.GetSpectatorInfo(ivar).SetMax(vmax[ivar]);
      //       if( TMath::Abs(vmax[ivar]-vmin[ivar]) <= FLT_MIN )
      //          Log() << kWARNING << "Spectator variable " << dsi.GetSpectatorInfo(ivar).GetExpression().Data() << " is constant." << Endl;
   }
   delete [] min;
   delete [] max;
   delete [] tgmin;
   delete [] tgmax;
   delete [] vmin;
   delete [] vmax;
}

//_______________________________________________________________________
TMatrixD* TMVA::DataSetFactory::CalcCorrelationMatrix( DataSet* ds, const UInt_t classNumber )
{
   // computes the correlation matrix for the input variables in the
   // dataset, using only events of class "classNumber"

   // first compute variance-covariance
   TMatrixD* mat = CalcCovarianceMatrix( ds, classNumber );

   // now the correlation
   UInt_t nvar = ds->GetNVariables(), ivar, jvar;

   for (ivar=0; ivar<nvar; ivar++) {
      for (jvar=0; jvar<nvar; jvar++) {
         if (ivar != jvar) {
            Double_t d = (*mat)(ivar, ivar)*(*mat)(jvar, jvar);
            if (d > 0) (*mat)(ivar, jvar) /= sqrt(d);
            else {
               Log() << kWARNING << "<GetCorrelationMatrix> Zero variances for variables "
                     << "(" << ivar << ", " << jvar << ") = " << d                   
                     << Endl;
               (*mat)(ivar, jvar) = 0;
            }
         }
      }
   }

   for (ivar=0; ivar<nvar; ivar++) (*mat)(ivar, ivar) = 1.0;

   return mat;
}
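
// Note (added for clarity): the off-diagonal elements computed above are the
// usual Pearson correlation coefficients derived from the covariance matrix C,
//
//    rho(i,j) = C(i,j) / sqrt( C(i,i)*C(j,j) ),   rho(i,i) = 1,
//
// with rho(i,j) set to 0 whenever one of the variances vanishes.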

//_______________________________________________________________________
TMatrixD* TMVA::DataSetFactory::CalcCovarianceMatrix( DataSet * ds, const UInt_t classNumber )
{
   // compute covariance matrix

   UInt_t nvar = ds->GetNVariables();
   UInt_t ivar = 0, jvar = 0;

   TMatrixD* mat = new TMatrixD( nvar, nvar );

   // init matrices
   TVectorD vec(nvar);
   TMatrixD mat2(nvar, nvar);
   for (ivar=0; ivar<nvar; ivar++) {
      vec(ivar) = 0;
      for (jvar=0; jvar<nvar; jvar++) mat2(ivar, jvar) = 0;
   }

   // perform event loop
   Double_t ic = 0;
   for (Int_t i=0; i<ds->GetNEvents(); i++) {

      Event * ev = ds->GetEvent(i);
      if (ev->GetClass() != classNumber ) continue;

      Double_t weight = ev->GetWeight();
      ic += weight; // count used events

      for (ivar=0; ivar<nvar; ivar++) {

         Double_t xi = ev->GetValue(ivar);
         vec(ivar) += xi*weight;
         mat2(ivar, ivar) += (xi*xi*weight);

         for (jvar=ivar+1; jvar<nvar; jvar++) {
            Double_t xj =  ev->GetValue(jvar);
            mat2(ivar, jvar) += (xi*xj*weight);
         }
      }
   }

   for (ivar=0; ivar<nvar; ivar++)
      for (jvar=ivar+1; jvar<nvar; jvar++)
         mat2(jvar, ivar) = mat2(ivar, jvar); // symmetric matrix


   // variance-covariance
   for (ivar=0; ivar<nvar; ivar++) {
      for (jvar=0; jvar<nvar; jvar++) {
         (*mat)(ivar, jvar) = mat2(ivar, jvar)/ic - vec(ivar)*vec(jvar)/(ic*ic);
      }
   }

   return mat;
}
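
// Note (added for clarity): the loop above accumulates sum(w), sum(w*x_i) and
// sum(w*x_i*x_j), so the returned matrix is the weighted estimator
//
//    C(i,j) = sum_k w_k x_ki x_kj / sum_k w_k
//             - ( sum_k w_k x_ki / sum_k w_k ) * ( sum_k w_k x_kj / sum_k w_k )
//
// where k runs over all events of the requested class (ic = sum_k w_k in the code).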

// --------------------------------------- new versions

//_______________________________________________________________________
void TMVA::DataSetFactory::InitOptions( TMVA::DataSetInfo& dsi, 
                                        TMVA::NumberPerClassOfTreeType& nTrainTestEvents, 
                                        TString& normMode, UInt_t& splitSeed, 
                                        TString& splitMode,
                                        TString& mixMode  ) 
{
   // the dataset splitting
   Configurable splitSpecs( dsi.GetSplitOptions() );
   splitSpecs.SetConfigName("DataSetFactory");
   splitSpecs.SetConfigDescription( "Configuration options given in the \"PrepareForTrainingAndTesting\" call; these options define the creation of the data sets used for training and expert validation by TMVA" );

   splitMode = "Random";    // the splitting mode
   splitSpecs.DeclareOptionRef( splitMode, "SplitMode",
                                "Method of picking training and testing events (default: random)" );
   splitSpecs.AddPreDefVal(TString("Random"));
   splitSpecs.AddPreDefVal(TString("Alternate"));
   splitSpecs.AddPreDefVal(TString("Block"));

   mixMode = "SameAsSplitMode";    // the mixing mode
   splitSpecs.DeclareOptionRef( mixMode, "MixMode",
                                "Method of mixing events of different classes into one dataset (default: SameAsSplitMode)" );
   splitSpecs.AddPreDefVal(TString("SameAsSplitMode"));
   splitSpecs.AddPreDefVal(TString("Random"));
   splitSpecs.AddPreDefVal(TString("Alternate"));
   splitSpecs.AddPreDefVal(TString("Block"));

   splitSeed = 100;
   splitSpecs.DeclareOptionRef( splitSeed, "SplitSeed",
                                "Seed for random event shuffling" );

   normMode = "NumEvents";  // the weight normalisation modes
   splitSpecs.DeclareOptionRef( normMode, "NormMode",
                                "Overall renormalisation of event-by-event weights (NumEvents: average weight of 1 per event, independently for signal and background; EqualNumEvents: average weight of 1 per event for signal, and sum of weights for background equal to sum of weights for signal)" );
   splitSpecs.AddPreDefVal(TString("None"));
   splitSpecs.AddPreDefVal(TString("NumEvents"));
   splitSpecs.AddPreDefVal(TString("EqualNumEvents"));

   // the number of events

   // initialization
   nTrainTestEvents.insert( TMVA::NumberPerClassOfTreeType::value_type( Types::kTraining, TMVA::NumberPerClass( dsi.GetNClasses() ) ) );
   nTrainTestEvents.insert( TMVA::NumberPerClassOfTreeType::value_type( Types::kTesting,  TMVA::NumberPerClass( dsi.GetNClasses() ) ) );

   // fill in the numbers
   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      nTrainTestEvents[Types::kTraining].at(cl)  = 0;
      nTrainTestEvents[Types::kTesting].at(cl)   = 0;

      TString clName = dsi.GetClassInfo(cl)->GetName();
      TString titleTrain =  TString().Format("Number of training events of class %s (default: 0 = all)",clName.Data()).Data();
      TString titleTest  =  TString().Format("Number of test events of class %s (default: 0 = all)",clName.Data()).Data();

      splitSpecs.DeclareOptionRef( nTrainTestEvents[Types::kTraining].at(cl) , TString("nTrain_")+clName, titleTrain );
      splitSpecs.DeclareOptionRef( nTrainTestEvents[Types::kTesting].at(cl)  , TString("nTest_")+clName , titleTest  );
   }

   splitSpecs.DeclareOptionRef( fVerbose, "V", "Verbosity (default: true)" );

   splitSpecs.DeclareOptionRef( fVerboseLevel=TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)" );
   splitSpecs.AddPreDefVal(TString("Debug"));
   splitSpecs.AddPreDefVal(TString("Verbose"));
   splitSpecs.AddPreDefVal(TString("Info"));

   splitSpecs.ParseOptions();
   splitSpecs.CheckForUnusedOptions();

   // output logging verbosity
   if (Verbose()) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Debug")   ==0) fLogger->SetMinType( kDEBUG );
   if (fVerboseLevel.CompareTo("Verbose") ==0) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Info")    ==0) fLogger->SetMinType( kINFO );

   // put all to upper case
   splitMode.ToUpper(); mixMode.ToUpper(); normMode.ToUpper();
   // adjust mixmode if same as splitmode option has been set
   Log() << kINFO << "Splitmode is: \"" << splitMode << "\" the mixmode is: \"" << mixMode << "\"" << Endl;
   if (mixMode=="SAMEASSPLITMODE") mixMode = splitMode;
   else if (mixMode!=splitMode) 
      Log() << kINFO << "DataSet splitmode="<<splitMode
            <<" differs from mixmode="<<mixMode<<Endl;
}
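
// Example of a user call that ends up in InitOptions (illustrative sketch; the
// class names "Signal" and "Background" are assumptions of this example):
//
//    factory->PrepareTrainingAndTestTree( "",
//       "nTrain_Signal=1000:nTrain_Background=1000:SplitMode=Random:"
//       "SplitSeed=100:NormMode=NumEvents:!V" );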


//_______________________________________________________________________
void  TMVA::DataSetFactory::BuildEventVector( TMVA::DataSetInfo& dsi, 
                                              TMVA::DataInputHandler& dataInput, 
                                              TMVA::EventVectorOfClassesOfTreeType& tmpEventVector )
{
   // build empty event vectors
   // distributes events between kTraining/kTesting/kMaxTreeType

   tmpEventVector.insert( std::make_pair(Types::kTraining   ,TMVA::EventVectorOfClasses(dsi.GetNClasses() ) ) );
   tmpEventVector.insert( std::make_pair(Types::kTesting    ,TMVA::EventVectorOfClasses(dsi.GetNClasses() ) ) );
   tmpEventVector.insert( std::make_pair(Types::kMaxTreeType,TMVA::EventVectorOfClasses(dsi.GetNClasses() ) ) );


   // create the type, weight and boostweight branches
   const UInt_t nvars    = dsi.GetNVariables();
   const UInt_t ntgts    = dsi.GetNTargets();
   const UInt_t nvis     = dsi.GetNSpectators();
   //   std::vector<Float_t> fmlEval(nvars+ntgts+1+1+nvis);     // +1+1 for results of evaluation of cut and weight ttreeformula  

   // number of signal and background events passing cuts
   std::vector< Int_t >    nInitialEvents( dsi.GetNClasses() );
   std::vector< Int_t >    nEvBeforeCut(   dsi.GetNClasses() );
   std::vector< Int_t >    nEvAfterCut(    dsi.GetNClasses() );
   std::vector< Float_t >  nWeEvBeforeCut( dsi.GetNClasses() );
   std::vector< Float_t >  nWeEvAfterCut(  dsi.GetNClasses() );
   std::vector< Double_t > nNegWeights(    dsi.GetNClasses() );
   std::vector< Float_t* > varAvLength(    dsi.GetNClasses() );

   Bool_t haveArrayVariable = kFALSE;
   Bool_t *varIsArray = new Bool_t[nvars];

   for (size_t i=0; i<varAvLength.size(); i++) {
      varAvLength[i] = new Float_t[nvars];
      for (UInt_t ivar=0; ivar<nvars; ivar++) {
         //varIsArray[ivar] = kFALSE;
         varAvLength[i][ivar] = 0;
      }
   }

   // if we work with chains we need to remember the current tree
   // if the chain jumps to a new tree we have to reset the formulas
   for (UInt_t cl=0; cl<dsi.GetNClasses(); cl++) {

      Log() << kINFO << "Create training and testing trees -- looping over class \"" 
            << dsi.GetClassInfo(cl)->GetName() << "\" ..." << Endl;

      // info output for weights
      const TString tmpWeight = dsi.GetClassInfo(cl)->GetWeight();
      if (tmpWeight!="") {
         Log() << kINFO << "Weight expression for class \"" << dsi.GetClassInfo(cl)->GetName() << "\": \""
               << tmpWeight << "\"" << Endl; 
      }
      else {
         Log() << kINFO << "No weight expression defined for class \"" << dsi.GetClassInfo(cl)->GetName() 
               << "\"" << Endl; 
      }

      // used for chains only
      TString currentFileName("");

      std::vector<TreeInfo>::const_iterator treeIt(dataInput.begin(dsi.GetClassInfo(cl)->GetName()));
      for (;treeIt!=dataInput.end(dsi.GetClassInfo(cl)->GetName()); treeIt++) {

         // read first the variables
         std::vector<Float_t> vars(nvars);
         std::vector<Float_t> tgts(ntgts);
         std::vector<Float_t> vis(nvis);
         TreeInfo currentInfo = *treeIt;

         Bool_t isChain = (TString("TChain") == currentInfo.GetTree()->ClassName());
         currentInfo.GetTree()->LoadTree(0);
         ChangeToNewTree( currentInfo, dsi );

         // count number of events in tree before cut
         nInitialEvents.at(cl) += currentInfo.GetTree()->GetEntries();

//          std::vector< std::pair< Long64_t, Types::ETreeType > >& userEvType = userDefinedEventTypes.at(cl);
//          if (userEvType.size() == 0 || userEvType.back().second != currentInfo.GetTreeType()) {
//             userEvType.push_back( std::make_pair< Long64_t, Types::ETreeType >(tmpEventVector.at(cl).size(), currentInfo.GetTreeType()) );
//          }

         // loop over events in ntuple
         for (Long64_t evtIdx = 0; evtIdx < currentInfo.GetTree()->GetEntries(); evtIdx++) {
            currentInfo.GetTree()->LoadTree(evtIdx);

            // may need to reload tree in case of chains
            if (isChain) {
               if (currentInfo.GetTree()->GetTree()->GetDirectory()->GetFile()->GetName() != currentFileName) {
                  currentFileName = currentInfo.GetTree()->GetTree()->GetDirectory()->GetFile()->GetName();
                  ChangeToNewTree( currentInfo, dsi );
               }
            }
            currentInfo.GetTree()->GetEntry(evtIdx);
            Int_t sizeOfArrays = 1;
            Int_t prevArrExpr = 0;

            // ======= evaluate all formulas =================

            // first we check if some of the formulas are arrays
            for (UInt_t ivar=0; ivar<nvars; ivar++) {
               Int_t ndata = fInputFormulas[ivar]->GetNdata();
               varAvLength[cl][ivar] += ndata;
               if (ndata == 1) continue;
               haveArrayVariable = kTRUE;
               varIsArray[ivar] = kTRUE;
               if (sizeOfArrays == 1) {
                  sizeOfArrays = ndata;
                  prevArrExpr = ivar;
               } 
               else if (sizeOfArrays!=ndata) {
                  Log() << kERROR << "ERROR while preparing training and testing trees:" << Endl;
                  Log() << "   multiple array-type expressions of different length were encountered" << Endl;
                  Log() << "   location of error: event " << evtIdx 
                        << " in tree " << currentInfo.GetTree()->GetName()
                        << " of file " << currentInfo.GetTree()->GetCurrentFile()->GetName() << Endl;
                  Log() << "   expression " << fInputFormulas[ivar]->GetTitle() << " has " 
                        << ndata << " entries, while" << Endl;
                  Log() << "   expression " << fInputFormulas[prevArrExpr]->GetTitle() << " has "
                        << fInputFormulas[prevArrExpr]->GetNdata() << " entries" << Endl;
                  Log() << kFATAL << "Need to abort" << Endl;
               }
            }

            // now we read the information
            for (Int_t idata = 0;  idata<sizeOfArrays; idata++) {
               Bool_t containsNaN = kFALSE;

               TTreeFormula* formula = 0;

               // the cut expression
               Float_t cutVal = 1;
               formula = fCutFormulas[cl];
               if (formula) {
                  Int_t ndata = formula->GetNdata();
                  cutVal = (ndata==1 ? 
                            formula->EvalInstance(0) :
                            formula->EvalInstance(idata));
                  if (TMath::IsNaN(cutVal)) {
                     containsNaN = kTRUE;
                     Log() << kWARNING << "Cut expression resolves to an invalid value (NaN): " 
                           << formula->GetTitle() << Endl;
                  }
               }

               // the input variables
               for (UInt_t ivar=0; ivar<nvars; ivar++) {
                  formula = fInputFormulas[ivar];
                  Int_t ndata = formula->GetNdata();
                  vars[ivar] = (ndata == 1 ? 
                                formula->EvalInstance(0) : 
                                formula->EvalInstance(idata));
                  if (TMath::IsNaN(vars[ivar])) {
                     containsNaN = kTRUE;
                     Log() << kWARNING << "Input expression resolves to an invalid value (NaN): " 
                           << formula->GetTitle() << Endl;
                  }
               }

               // the targets
               for (UInt_t itrgt=0; itrgt<ntgts; itrgt++) {
                  formula = fTargetFormulas[itrgt];
                  Int_t ndata = formula->GetNdata();
                  tgts[itrgt] = (ndata == 1 ? 
                                 formula->EvalInstance(0) : 
                                 formula->EvalInstance(idata));
                  if (TMath::IsNaN(tgts[itrgt])) {
                     containsNaN = kTRUE;
                     Log() << kWARNING << "Target expression resolves to an invalid value (NaN): " 
                           << formula->GetTitle() << Endl;
                  }
               }

               // the spectators
               for (UInt_t itVis=0; itVis<nvis; itVis++) {
                  formula = fSpectatorFormulas[itVis];
                  Int_t ndata = formula->GetNdata();
                  vis[itVis] = (ndata == 1 ? 
                                formula->EvalInstance(0) : 
                                formula->EvalInstance(idata));
                  if (TMath::IsNaN(vis[itVis])) {
                     containsNaN = kTRUE;
                     Log() << kWARNING << "Spectator expression resolves to an invalid value (NaN): " 
                           << formula->GetTitle() << Endl;
                  }
               }


               // the weight
               Float_t weight = currentInfo.GetWeight(); // multiply by tree weight
               formula = fWeightFormula[cl];
               if (formula!=0) {
                  Int_t ndata = formula->GetNdata();
                  weight *= (ndata == 1 ?
                             formula->EvalInstance() :
                             formula->EvalInstance(idata));
                  if (TMath::IsNaN(weight)) {
                     containsNaN = kTRUE;
                     Log() << kWARNING << "Weight expression resolves to an invalid value (NaN): " 
                           << formula->GetTitle() << Endl;
                  }
               }

               // count the events before rejection due to cut or NaN value
               // (weighted and unweighted)
               nEvBeforeCut.at(cl)++;
               if (!TMath::IsNaN(weight))
                  nWeEvBeforeCut.at(cl) += weight;

               // apply the cut
               // skip rest if cut is not fulfilled
               if (cutVal<0.5) continue;

               // global flag if negative weights exist -> can be used by classifiers who may 
               // require special data treatment (also print warning)
               if (weight < 0) nNegWeights.at(cl)++;

               // reject the event if any of the formulas evaluated to NaN
               if (containsNaN) {
                  Log() << kWARNING << "Event " << evtIdx << " rejected" << Endl;
                  continue;
               }

               // count the events after rejection due to cut or NaN value
               // (weighted and unweighted)
               nEvAfterCut.at(cl)++;
               nWeEvAfterCut.at(cl) += weight;

               // event accepted, fill temporary ntuple
               tmpEventVector.find(currentInfo.GetTreeType())->second.at(cl).push_back(new Event(vars, tgts , vis, cl , weight));

            }
         }

         currentInfo.GetTree()->ResetBranchAddresses();
      }

//       // compute renormalisation factors
//       renormFactor.at(cl) = nTempEvents.at(cl)/sumOfWeights.at(cl); --> will be done in dedicated member function
   }

   // for output, check maximum class name length
   Int_t maxL = dsi.GetClassNameMaxLength();

   Log() << kINFO << "Number of events in input trees (after possible flattening of arrays):" << Endl;
   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      Log() << kINFO << "    " 
            << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName() 
            << "      -- number of events       : "
            << std::setw(5) << nEvBeforeCut.at(cl) 
            << "  / sum of weights: " << std::setw(5) << nWeEvBeforeCut.at(cl) << Endl;
   }

   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      Log() << kINFO << "    " << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName() 
            <<" tree -- total number of entries: " 
            << std::setw(5) << dataInput.GetEntries(dsi.GetClassInfo(cl)->GetName()) << Endl;
   }

   Log() << kINFO << "Preselection:" << Endl;
   if (dsi.HasCuts()) {
      for (UInt_t cl = 0; cl< dsi.GetNClasses(); cl++) {
         Log() << kINFO << "    " << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName() 
               << " requirement: \"" << dsi.GetClassInfo(cl)->GetCut() << "\"" << Endl;
         Log() << kINFO << "    " 
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName() 
               << "      -- number of events passed: "
               << std::setw(5) << nEvAfterCut.at(cl)
               << "  / sum of weights: " << std::setw(5) << nWeEvAfterCut.at(cl) << Endl;
         Log() << kINFO << "    " 
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName() 
               << "      -- efficiency             : "
               << std::setw(6) << nWeEvAfterCut.at(cl)/nWeEvBeforeCut.at(cl) << Endl;
      }
   }
   else Log() << kINFO << "    No preselection cuts applied on event classes" << Endl;

   delete[] varIsArray;
   for (size_t i=0; i<varAvLength.size(); i++)
      delete[] varAvLength[i];

}

//_______________________________________________________________________
TMVA::DataSet*  TMVA::DataSetFactory::MixEvents( DataSetInfo& dsi, 
                                                 TMVA::EventVectorOfClassesOfTreeType& tmpEventVector, 
                                                 TMVA::NumberPerClassOfTreeType& nTrainTestEvents,
                                                 const TString& splitMode,
                                                 const TString& mixMode, 
                                                 const TString& normMode, 
                                                 UInt_t splitSeed)
{
   // Select and distribute unassigned events to kTraining and kTesting
   Bool_t emptyUndefined  = kTRUE;

   // check if the vectors of all classes are empty
   for( Int_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
      emptyUndefined &= tmpEventVector[Types::kMaxTreeType].at(cls).empty();
   }

   TMVA::RandomGenerator rndm( splitSeed );

   // ==== splitting of undefined events to kTraining and kTesting

   // if splitMode contains "RANDOM", then shuffle the undefined events
   if (splitMode.Contains( "RANDOM" ) && !emptyUndefined ) {
      Log() << kDEBUG << "randomly shuffling events which are not yet associated to testing or training"<<Endl;
      // random shuffle the undefined events of each class
      for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
         std::random_shuffle(tmpEventVector[Types::kMaxTreeType].at(cls).begin(), 
                             tmpEventVector[Types::kMaxTreeType].at(cls).end(),
                             rndm );
      }
   }

   // check for each class the number of training and testing events, the requested number and the available number
   Log() << kDEBUG << "SPLITTING ========" << Endl;
   for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
      Log() << kDEBUG << "---- class " << cls << Endl;
      Log() << kDEBUG << "check the requested and available numbers of training/testing events for class " << cls << Endl;

      // check if enough or too many events are already in the training/testing eventvectors of the class cls
      EventVector& eventVectorTraining = tmpEventVector.find( Types::kTraining    )->second.at(cls);
      EventVector& eventVectorTesting  = tmpEventVector.find( Types::kTesting     )->second.at(cls);
      EventVector& eventVectorUndefined= tmpEventVector.find( Types::kMaxTreeType )->second.at(cls);

      Int_t alreadyAvailableTraining   = eventVectorTraining.size();
      Int_t alreadyAvailableTesting    = eventVectorTesting.size();
      Int_t availableUndefined         = eventVectorUndefined.size();

      Int_t requestedTraining          = nTrainTestEvents.find( Types::kTraining )->second.at(cls);
      Int_t requestedTesting           = nTrainTestEvents.find( Types::kTesting  )->second.at(cls);

      Log() << kDEBUG << "availableTraining  " << alreadyAvailableTraining << Endl;
      Log() << kDEBUG << "availableTesting   " << alreadyAvailableTesting << Endl;
      Log() << kDEBUG << "availableUndefined " << availableUndefined << Endl;
      Log() << kDEBUG << "requestedTraining  " << requestedTraining << Endl;
      Log() << kDEBUG << "requestedTesting   " << requestedTesting << Endl;
      //
      // nomenclature: r = available training events
      //               s = available testing events
      //               u = available undefined events
      //               R = requested training events
      //               S = requested testing events
      //               nR = number of events from which the training sample is drawn
      //               nS = number of events from which the test sample is drawn
      //               Theta(x) = x if x>0, else 0
      //               we have: nR + nS = r+s+u
      //               free events: Nfree = u - Theta(R-r) - Theta(S-s)
      //               nR = max(R,r) + 0.5 * Nfree
      //               nS = max(S,s) + 0.5 * Nfree
      //               check: nR + nS = R+S+u-(R-r)-(S-s) = u+r+s  (for R>r, S>s)
      //               check: nR + nS = r+S+u-(S-s)       = u+r+s  (for r>R, S>s)

      // EVT: three different cases might occur here
      //
      // Case a
      // requestedTraining and requestedTesting > 0
      // free events: Nfree = u - Theta(R-r) - Theta(S-s)
      //              nR = max(R,r) + 0.5 * Nfree
      //              nS = max(S,s) + 0.5 * Nfree
      //
      // Case b
      // exactly one of requestedTraining or requestedTesting > 0
      // assume training R > 0
      //    nR = max(R,r)
      //    nS = s+u+r-nR
      //    and s = nS
      //
      // Case c
      // requestedTraining = 0, requestedTesting = 0
      // Nfree = u - |r-s|
      // if Nfree >= 0
      //    R = max(r,s) + 0.5 * Nfree = S
      // else if r > s
      //    R = r; S = s+u
      // else
      //    R = r+u; S = s
      //
      // Next steps:
      // determine the event numbers R, S, nR, nS
      // distribute the undefined events according to nR, nS
      // finally determine the actual sub-samples from nR and nS to be used in training / testing
      //
      // implementation of case C)
      int useForTesting, useForTraining;
      if( (requestedTraining == 0) && (requestedTesting == 0)){ 
         // 0 means automatic distribution of events
         Log() << kDEBUG << "requested 0" << Endl;
         // try to get the same number of events in training and testing for this class (balance)
         Int_t NFree = availableUndefined - TMath::Abs(alreadyAvailableTraining - alreadyAvailableTesting);
         if (NFree >=0){
            requestedTraining = TMath::Max(alreadyAvailableTraining,alreadyAvailableTesting) + NFree/2;
            requestedTesting  = availableUndefined+alreadyAvailableTraining+alreadyAvailableTesting - requestedTraining; // the rest
         } else if (alreadyAvailableTraining > alreadyAvailableTesting){ // r>s
            requestedTraining = alreadyAvailableTraining;
            requestedTesting  = alreadyAvailableTesting +availableUndefined;
         }
         else {
            requestedTraining = alreadyAvailableTraining+availableUndefined;
            requestedTesting  = alreadyAvailableTesting;
         }
         useForTraining = requestedTraining; 
         useForTesting  = requestedTesting; 
      }
      else if ((requestedTesting == 0)){ // case B)
         useForTraining = TMath::Max(requestedTraining,alreadyAvailableTraining);
         useForTesting= availableUndefined+alreadyAvailableTraining+alreadyAvailableTesting - useForTraining; // the rest
         requestedTesting = useForTesting;
      }
      else if ((requestedTraining == 0)){ // case B)
         useForTesting = TMath::Max(requestedTesting,alreadyAvailableTesting);
         useForTraining= availableUndefined+alreadyAvailableTraining+alreadyAvailableTesting - useForTesting; // the rest
         requestedTraining = useForTraining;
      }
      else{ // case A
         int NFree = availableUndefined-TMath::Max(requestedTraining-alreadyAvailableTraining,0)-TMath::Max(requestedTesting-alreadyAvailableTesting,0);
         if (NFree <0) NFree = 0;
         useForTraining = TMath::Max(requestedTraining,alreadyAvailableTraining) + NFree/2;
         useForTesting= availableUndefined+alreadyAvailableTraining+alreadyAvailableTesting - useForTraining; // the rest
      }
      Log() << kDEBUG << "determined event sample size to select training sample from="<<useForTraining<<Endl;
      Log() << kDEBUG << "determined event sample size to select test sample from="<<useForTesting<<Endl;
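
      // Worked example (illustrative numbers, added for clarity): with
      // r = s = 0 already assigned, u = 1000 undefined events and no request
      // (R = S = 0, case c), Nfree = 1000, so useForTraining = 0 + 1000/2 = 500
      // and useForTesting = 1000 - 500 = 500. With R = 800, S = 0 (case b):
      // useForTraining = max(800,0) = 800 and useForTesting = 1000 - 800 = 200.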
01072       
01073 
01074       // associate undefined events 
01075       if( splitMode == "ALTERNATE" ){
01076          Log() << kDEBUG << "split 'ALTERNATE'" << Endl;
01077          Int_t nTraining = alreadyAvailableTraining;
01078          Int_t nTesting  = alreadyAvailableTesting;
01079          for( EventVector::iterator it = eventVectorUndefined.begin(), itEnd = eventVectorUndefined.end(); it != itEnd; ){
01080             ++nTraining;
01081             if( nTraining <= requestedTraining ){
01082                eventVectorTraining.insert( eventVectorTraining.end(), (*it) );
01083                ++it;
01084             }
01085             if( it != itEnd ){
01086                ++nTesting;
01087                eventVectorTesting.insert( eventVectorTesting.end(), (*it) );
01088                ++it;
01089             }
01090          }
01091       }else{
01092          Log() << kDEBUG << "split '" << splitMode << "'" << Endl;
01093 
01094          // test if enough events are available
01095          Log() << kDEBUG << "availableundefined : " << availableUndefined << Endl;
01096          Log() << kDEBUG << "useForTraining     : " << useForTraining << Endl;
01097          Log() << kDEBUG << "useForTesting      : " << useForTesting  << Endl;
01098          Log() << kDEBUG << "alreadyAvailableTraining      : " << alreadyAvailableTraining  << Endl;
01099          Log() << kDEBUG << "alreadyAvailableTesting       : " << alreadyAvailableTesting  << Endl;
01100 
01101          if( availableUndefined<(useForTraining-alreadyAvailableTraining) ||
01102              availableUndefined<(useForTesting -alreadyAvailableTesting ) || 
01103              availableUndefined<(useForTraining+useForTesting-alreadyAvailableTraining-alreadyAvailableTesting ) ){
01104             Log() << kFATAL << "More events requested than available!" << Endl;
01105          }
01106 
01107          // select the events
01108          if (useForTraining>alreadyAvailableTraining){
01109             eventVectorTraining.insert(  eventVectorTraining.end() , eventVectorUndefined.begin(), eventVectorUndefined.begin()+ useForTraining- alreadyAvailableTraining );
01110             eventVectorUndefined.erase( eventVectorUndefined.begin(), eventVectorUndefined.begin() + useForTraining- alreadyAvailableTraining);
01111          }
01112          if (useForTesting>alreadyAvailableTesting){
01113             eventVectorTesting.insert(  eventVectorTesting.end() , eventVectorUndefined.begin(), eventVectorUndefined.begin()+ useForTesting- alreadyAvailableTesting );
01114          }
01115       }
01116       eventVectorUndefined.clear();      
01117       // finally shorten the event vectors to the requested size by removing random events
01118       if (splitMode.Contains( "RANDOM" )){
01119          UInt_t sizeTraining  = eventVectorTraining.size();
01120          if( sizeTraining > UInt_t(requestedTraining) ){
01121             std::vector<UInt_t> indicesTraining( sizeTraining );
01122             // fill with the indices 0,1,2,...
01123             std::generate( indicesTraining.begin(), indicesTraining.end(), TMVA::Increment<UInt_t>(0) );
01124             // shuffle the indices
01125             std::random_shuffle( indicesTraining.begin(), indicesTraining.end(), rndm );
01126             // keep only the first (sizeTraining - requestedTraining) indices: these mark the events to be deleted
01127             indicesTraining.erase( indicesTraining.begin()+sizeTraining-UInt_t(requestedTraining), indicesTraining.end() );
01128             // delete all events with the given indices
01129             for( std::vector<UInt_t>::iterator it = indicesTraining.begin(), itEnd = indicesTraining.end(); it != itEnd; ++it ){
01130                delete eventVectorTraining.at( (*it) ); // delete event
01131                eventVectorTraining.at( (*it) ) = NULL; // set pointer to NULL
01132             }
01133             // now remove and erase all events with pointer==NULL
01134             eventVectorTraining.erase( std::remove( eventVectorTraining.begin(), eventVectorTraining.end(), (void*)NULL ), eventVectorTraining.end() );
01135          }
01136 
01137          UInt_t sizeTesting   = eventVectorTesting.size();
01138          if( sizeTesting > UInt_t(requestedTesting) ){
01139             std::vector<UInt_t> indicesTesting( sizeTesting );
01140             // fill with the indices 0,1,2,...
01141             std::generate( indicesTesting.begin(), indicesTesting.end(), TMVA::Increment<UInt_t>(0) );
01142             // shuffle the indices
01143             std::random_shuffle( indicesTesting.begin(), indicesTesting.end(), rndm );
01144             // keep only the first (sizeTesting - requestedTesting) indices: these mark the events to be deleted
01145             indicesTesting.erase( indicesTesting.begin()+sizeTesting-UInt_t(requestedTesting), indicesTesting.end() );
01146             // delete all events with the given indices
01147             for( std::vector<UInt_t>::iterator it = indicesTesting.begin(), itEnd = indicesTesting.end(); it != itEnd; ++it ){
01148                delete eventVectorTesting.at( (*it) ); // delete event
01149                eventVectorTesting.at( (*it) ) = NULL; // set pointer to NULL
01150             }
01151             // now remove and erase all events with pointer==NULL
01152             eventVectorTesting.erase( std::remove( eventVectorTesting.begin(), eventVectorTesting.end(), (void*)NULL ), eventVectorTesting.end() );
01153          }
01154       }
01155       else { // erase at end
01156          if( eventVectorTraining.size() < UInt_t(requestedTraining) )
01157             Log() << kWARNING << "DataSetFactory: more training samples requested than eventVectorTraining contains.\n"
01158                   << "There is probably an issue. Please contact the TMVA developers." << Endl;
01159          else { // trim only when there is a surplus: begin()+requestedTraining is an invalid iterator otherwise
01160             std::for_each( eventVectorTraining.begin()+requestedTraining, eventVectorTraining.end(), DeleteFunctor<Event>() );
01161             eventVectorTraining.erase(eventVectorTraining.begin()+requestedTraining,eventVectorTraining.end()); }
01162          if( eventVectorTesting.size() < UInt_t(requestedTesting) )
01163             Log() << kWARNING << "DataSetFactory: more testing samples requested than eventVectorTesting contains.\n"
01164                   << "There is probably an issue. Please contact the TMVA developers." << Endl;
01165          else { // trim only when there is a surplus: begin()+requestedTesting is an invalid iterator otherwise
01166             std::for_each( eventVectorTesting.begin()+requestedTesting, eventVectorTesting.end(), DeleteFunctor<Event>() );
                  eventVectorTesting.erase(eventVectorTesting.begin()+requestedTesting,eventVectorTesting.end()); }
01167       }
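// ------------------------------------------------------------------
// Editorial sketch (not part of the original source): the shuffled-index
// removal used in the RANDOM branch above, as a self-contained helper.
// 'removeRandomElements' and the int payload are illustrative; <algorithm>
// and <vector> are already included in this file.
void removeRandomElements( std::vector<int*>& vec, UInt_t nKeep )
{
   if( vec.size() <= nKeep ) return;
   std::vector<UInt_t> indices( vec.size() );
   for( UInt_t i = 0; i < indices.size(); ++i ) indices[i] = i;  // 0,1,2,...
   std::random_shuffle( indices.begin(), indices.end() );        // random order
   indices.resize( vec.size() - nKeep );                         // these mark the victims
   for( UInt_t i = 0; i < indices.size(); ++i ){
      delete vec[ indices[i] ];  // free the element ...
      vec[ indices[i] ] = NULL;  // ... and mark its slot
   }
   // compact the vector by dropping all NULL slots in one pass
   vec.erase( std::remove( vec.begin(), vec.end(), (int*)NULL ), vec.end() );
}
// ------------------------------------------------------------------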
01168    }
01169 
01170    TMVA::DataSetFactory::RenormEvents( dsi, tmpEventVector, normMode );
01171 
01172    Int_t trainingSize = 0;
01173    Int_t testingSize  = 0;
01174 
01175    // sum up number of training and testing events
01176    for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
01177       trainingSize += tmpEventVector[Types::kTraining].at(cls).size();
01178       testingSize  += tmpEventVector[Types::kTesting].at(cls).size();
01179    }
01180 
01181    // --- collect all training (testing) events into the training (testing) eventvector
01182 
01183    // create the event vectors and reserve enough space
01184    EventVector* trainingEventVector = new EventVector();
01185    EventVector* testingEventVector  = new EventVector();
01186 
01187    trainingEventVector->reserve( trainingSize );
01188    testingEventVector->reserve( testingSize );
01189 
01190 
01191    // collect the events
01192 
01193    // mix the classes within the kTraining and kTesting event vectors
01194    Log() << kDEBUG << " MIXING ============= " << Endl;
01195 
01196    if( mixMode == "ALTERNATE" ){
01197       // Inform the user if alternate mixmode is requested for
01198       // classes with different numbers of events: it works, but the alternation stops at the last event of the smaller class
01199       for( UInt_t cls = 1; cls < dsi.GetNClasses(); ++cls ){
01200          if (tmpEventVector[Types::kTraining].at(cls).size() != tmpEventVector[Types::kTraining].at(0).size()){
01201             Log() << kINFO << "Training sample: You are trying to mix events in alternate mode although the classes have different event numbers. This works but the alternation stops at the last event of the smaller class."<<Endl;
01202          }
01203          if (tmpEventVector[Types::kTesting].at(cls).size() != tmpEventVector[Types::kTesting].at(0).size()){
01204             Log() << kINFO << "Testing sample: You are trying to mix events in alternate mode although the classes have different event numbers. This works but the alternation stops at the last event of the smaller class."<<Endl;
01205          }
01206       }
01207       typedef EventVector::iterator EvtVecIt;
01208       EvtVecIt itEvent, itEventEnd;
01209 
01210       // insert first class
01211       Log() << kDEBUG << "insert class 0 into training and test vector" << Endl;
01212       trainingEventVector->insert( trainingEventVector->end(), tmpEventVector[Types::kTraining].at(0).begin(), tmpEventVector[Types::kTraining].at(0).end() );
01213       testingEventVector->insert( testingEventVector->end(),   tmpEventVector[Types::kTesting].at(0).begin(),  tmpEventVector[Types::kTesting].at(0).end() );
01214       
01215       // insert the other classes by weaving them into the vectors just filled
01216       UInt_t pos; // next insertion index while weaving a class in
01217       for( UInt_t cls = 1; cls < dsi.GetNClasses(); ++cls ){
01218          Log() << kDEBUG << "insert class " << cls << Endl;
01219          // training vector: weave the new class in after every 'cls' events of the previous classes
01220          pos = cls; // first insertion index (an index is used here: starting an iterator at begin()-1 is undefined behaviour)
01221          // loop over source 
01222          for( itEvent = tmpEventVector[Types::kTraining].at(cls).begin(), itEventEnd = tmpEventVector[Types::kTraining].at(cls).end(); itEvent != itEventEnd; ++itEvent ){
01223             if( pos > trainingEventVector->size() ) {
01224                // not enough events of the previous classes left: fill in the rest without mixing
01225                trainingEventVector->insert( trainingEventVector->end(), itEvent, itEventEnd );
01226                break;
01227             }else{
01228                trainingEventVector->insert( trainingEventVector->begin() + pos, (*itEvent) ); // fill event
01229                pos += cls+1; // step over the event just inserted plus one event of each previous class
01230             }
01231          }
01233          // testing vector: same weaving as for the training vector
01234          pos = cls; // first insertion index
01235          // loop over source 
01236          for( itEvent = tmpEventVector[Types::kTesting].at(cls).begin(), itEventEnd = tmpEventVector[Types::kTesting].at(cls).end(); itEvent != itEventEnd; ++itEvent ){
01237             if( pos > testingEventVector->size() ) {
01238                // not enough events of the previous classes left: fill in the rest without mixing
01239                testingEventVector->insert( testingEventVector->end(), itEvent, itEventEnd );
01240                break;
01241             }else{
01242                testingEventVector->insert( testingEventVector->begin() + pos, (*itEvent) ); // fill event
01243                pos += cls+1; // step over the event just inserted plus one event of each previous class
01244             }
01245          }
01247       }
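// ------------------------------------------------------------------
// Editorial note (not part of the original source): for three equally
// sized classes A, B, C the weaving above yields A B C A B C ... ;
// if a later class runs out of events the tail simply stays unmixed
// (e.g. A B C A B C A B), and if it has surplus events they are
// appended unmixed at the end.
// ------------------------------------------------------------------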
01248 
01249       // debugging output: classnumbers of all events in training and testing vectors
01250       //       std::cout << std::endl;
01251       //       std::cout << "TRAINING VECTOR" << std::endl;
01252       //       std::transform( trainingEventVector->begin(), trainingEventVector->end(), ostream_iterator<Int_t>(std::cout, "|"), std::mem_fun(&TMVA::Event::GetClass) );
01253       
01254       //       std::cout << std::endl;
01255       //       std::cout << "TESTING VECTOR" << std::endl;
01256       //       std::transform( testingEventVector->begin(), testingEventVector->end(), ostream_iterator<Int_t>(std::cout, "|"), std::mem_fun(&TMVA::Event::GetClass) );
01257       //       std::cout << std::endl;
01258 
01259    }else{ 
01260       for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
01261          trainingEventVector->insert( trainingEventVector->end(), tmpEventVector[Types::kTraining].at(cls).begin(), tmpEventVector[Types::kTraining].at(cls).end() );
01262          testingEventVector->insert ( testingEventVector->end(),  tmpEventVector[Types::kTesting].at(cls).begin(),  tmpEventVector[Types::kTesting].at(cls).end()  );
01263       }
01264    }
01265 
01266    //    std::cout << "trainingEventVector " << trainingEventVector->size() << std::endl;
01267    //    std::cout << "testingEventVector  " << testingEventVector->size() << std::endl;
01268 
01269    //    std::transform( trainingEventVector->begin(), trainingEventVector->end(), ostream_iterator<Int_t>(std::cout, "> \n"), std::mem_fun(&TMVA::Event::GetNVariables) );
01270    //    std::transform( testingEventVector->begin(), testingEventVector->end(), ostream_iterator<Int_t>(std::cout, "> \n"), std::mem_fun(&TMVA::Event::GetNVariables) );
01271 
01272    // delete the tmpEventVector (but not the events therein)
01273    tmpEventVector[Types::kTraining].clear();
01274    tmpEventVector[Types::kTesting].clear();
01275 
01276    tmpEventVector[Types::kMaxTreeType].clear();
01277 
01278    if (mixMode == "RANDOM") {
01279       Log() << kDEBUG << "shuffling events"<<Endl;
01280 
01281       //       std::cout << "before" << std::endl;
01282       //       std::for_each( trainingEventVector->begin(), trainingEventVector->begin()+10, std::bind2nd(std::mem_fun(&TMVA::Event::Print),std::cout) );
01283       
01284       std::random_shuffle( trainingEventVector->begin(), trainingEventVector->end(), rndm );
01285       std::random_shuffle( testingEventVector->begin(),  testingEventVector->end(),  rndm  );
01286 
01287       //       std::cout << "after" << std::endl;
01288       //       std::for_each( trainingEventVector->begin(), trainingEventVector->begin()+10, std::bind2nd(std::mem_fun(&TMVA::Event::Print),std::cout) );
01289    }
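// ------------------------------------------------------------------
// Editorial sketch (not part of the original source): the kind of functor
// std::random_shuffle expects for 'rndm' above. A TRandom3-backed generator
// of this shape is assumed; the class name is illustrative (TRandom3.h is
// already included in this file, ptrdiff_t comes with <cstddef>).
struct EvtShuffler {
   EvtShuffler( UInt_t seed ) : fRandom( seed ) {}
   // must return a uniformly distributed value in [0,n)
   ptrdiff_t operator()( ptrdiff_t n ) { return ptrdiff_t( fRandom.Integer( n ) ); }
   TRandom3 fRandom;
};
// usage:
//    EvtShuffler shuffler( 100 );
//    std::random_shuffle( vec.begin(), vec.end(), shuffler );
// ------------------------------------------------------------------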
01290 
01291    Log() << kDEBUG << "trainingEventVector " << trainingEventVector->size() << Endl;
01292    Log() << kDEBUG << "testingEventVector  " << testingEventVector->size() << Endl;
01293 
01294    // create dataset
01295    DataSet* ds = new DataSet(dsi);
01296 
01297    Log() << kINFO << "Create internal training tree" << Endl;        
01298    ds->SetEventCollection(trainingEventVector, Types::kTraining ); 
01299    Log() << kINFO << "Create internal testing tree" << Endl;        
01300    ds->SetEventCollection(testingEventVector,  Types::kTesting  ); 
01301 
01302 
01303    return ds;
01304    
01305 }
01306 
01307 
01308 
01309 //_______________________________________________________________________
01310 void  TMVA::DataSetFactory::RenormEvents( TMVA::DataSetInfo& dsi, 
01311                                           TMVA::EventVectorOfClassesOfTreeType& tmpEventVector, 
01312                                           const TString&        normMode )
01313 {
01314    // ============================================================
01315    // renormalisation
01316    // ============================================================
01317 
01320    // print rescaling info
01321    if (normMode == "NONE") {
01322       Log() << kINFO << "No weight renormalisation applied: use original event weights" << Endl;
01323       return;
01324    }
01325 
01326    // ---------------------------------
01327    // compute sizes and sums of weights
01328    Int_t trainingSize = 0;
01329    Int_t testingSize  = 0;
01330 
01331    ValuePerClass trainingSumWeightsPerClass( dsi.GetNClasses() );
01332    ValuePerClass testingSumWeightsPerClass( dsi.GetNClasses() );
01333 
01334    NumberPerClass trainingSizePerClass( dsi.GetNClasses() );
01335    NumberPerClass testingSizePerClass( dsi.GetNClasses() );
01336 
01337    Double_t trainingSumWeights = 0;
01338    Double_t testingSumWeights  = 0;
01339 
01340    for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
01341       trainingSizePerClass.at(cls) = tmpEventVector[Types::kTraining].at(cls).size();
01342       testingSizePerClass.at(cls)  = tmpEventVector[Types::kTesting].at(cls).size();
01343 
01344       trainingSize += trainingSizePerClass.at(cls);
01345       testingSize  += testingSizePerClass.at(cls);
01346 
01347       // the functional solution
01348       // sum up the weights in Double_t although the individual weights are Float_t to prevent rounding issues in addition of floating points
01349       //
01350       // accumulate --> does what the name says
01351       //     begin() and end() denote the range of the vector to be accumulated
01352       //     Double_t(0) tells accumulate the type and the starting value
01353       //     compose_binary creates a BinaryFunction of ...
01354       //         std::plus<Double_t>() knows how to sum up two doubles
01355       //         null<Double_t>() leaves the first argument (the running sum) unchanged and returns it
01356       //         std::mem_fun knows how to call a member function (type and member-function given as argument) and return the result
01357       //
01358       // all together sums up all the event-weights of the events in the vector and returns it
01359       trainingSumWeightsPerClass.at(cls) = std::accumulate( tmpEventVector[Types::kTraining].at(cls).begin(),
01360                                                             tmpEventVector[Types::kTraining].at(cls).end(),
01361                                                             Double_t(0),
01362                                                             compose_binary( std::plus<Double_t>(),
01363                                                                             null<Double_t>(),
01364                                                                             std::mem_fun(&TMVA::Event::GetOriginalWeight) ) );
01365 
01366       testingSumWeightsPerClass.at(cls)  = std::accumulate( tmpEventVector[Types::kTesting].at(cls).begin(),
01367                                                             tmpEventVector[Types::kTesting].at(cls).end(),
01368                                                             Double_t(0),
01369                                                             compose_binary( std::plus<Double_t>(),
01370                                                                             null<Double_t>(),
01371                                                                             std::mem_fun(&TMVA::Event::GetOriginalWeight) ) );
01372 
01373 
01374       trainingSumWeights += trainingSumWeightsPerClass.at(cls);
01375       testingSumWeights  += testingSumWeightsPerClass.at(cls);
01376    }
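// ------------------------------------------------------------------
// Editorial sketch (not part of the original source): the accumulate/
// compose_binary construction above is equivalent to this plain loop,
// assuming the EventVector (std::vector<Event*>) and Event types used
// throughout this file ('SumOfOriginalWeights' is illustrative):
Double_t SumOfOriginalWeights( const EventVector& events )
{
   Double_t sum = 0;  // accumulate in double precision to limit rounding error
   for( EventVector::const_iterator it = events.begin(); it != events.end(); ++it )
      sum += (*it)->GetOriginalWeight();
   return sum;
}
// ------------------------------------------------------------------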
01377 
01378    // ---------------------------------
01379    // compute renormalization factors
01380 
01381    ValuePerClass renormFactor( dsi.GetNClasses() );
01382 
01383    if (normMode == "NUMEVENTS") {
01384       Log() << kINFO << "Weight renormalisation mode: \"NumEvents\": renormalise independently the ..." << Endl;
01385       Log() << kINFO << "... class weights so that Sum[i=1..N_j]{w_i} = N_j, j=0,1,2..." << Endl;
01386       Log() << kINFO << "... (note that N_j is the sum of training and test events)" << Endl;
01387 
01388       for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
01389          renormFactor.at(cls) = ( (trainingSizePerClass.at(cls) + testingSizePerClass.at(cls))/
01390                                   (trainingSumWeightsPerClass.at(cls) + testingSumWeightsPerClass.at(cls)) );
01391       }
01392    }
01393    else if (normMode == "EQUALNUMEVENTS") {
01394       Log() << kINFO << "Weight renormalisation mode: \"EqualNumEvents\": renormalise class weights ..." << Endl;
01395       Log() << kINFO << "... so that Sum[i=1..N_j]{w_i} = N_classA, j=classA, classB, ..." << Endl;
01396       Log() << kINFO << "... (note that N_j is the sum of training and test events)" << Endl;
01397 
01398       for (UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ) {
01399          renormFactor.at(cls) = Float_t(trainingSizePerClass.at(cls)+testingSizePerClass.at(cls))/
01400             (trainingSumWeightsPerClass.at(cls)+testingSumWeightsPerClass.at(cls));
01401       }
01402       // normalize to size of first class
01403       UInt_t referenceClass = 0;
01404       for (UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ) {
01405          if( cls == referenceClass ) continue;
01406          renormFactor.at(cls) *= Float_t(trainingSizePerClass.at(referenceClass)+testingSizePerClass.at(referenceClass) )/
01407             Float_t( trainingSizePerClass.at(cls)+testingSizePerClass.at(cls) );
01408       }
01409    }
01410    else {
01411       Log() << kFATAL << "<PrepareForTrainingAndTesting> Unknown NormMode: " << normMode << Endl;
01412    }
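// ------------------------------------------------------------------
// Editorial worked example (not part of the original source): two classes,
//    signal    : 1000 events, sum of original weights  250
//    background: 2000 events, sum of original weights 4000
// "NumEvents"      : f_sig = 1000/250 = 4.0, f_bkg = 2000/4000 = 0.5,
//                    so each class's weights sum to its own event count.
// "EqualNumEvents" : f_bkg is additionally multiplied by N_sig/N_bkg = 0.5,
//                    giving f_bkg = 0.25, so both classes' weights sum to
//                    N_sig = 1000 (the reference class size).
// ------------------------------------------------------------------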
01413 
01414    // ---------------------------------
01415    // now apply the normalization factors
01416    Int_t maxL = dsi.GetClassNameMaxLength();
01417    for (UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls<clsEnd; ++cls) { 
01418       Log() << kINFO << "--> Rescale " << setiosflags(ios::left) << std::setw(maxL) 
01419             << dsi.GetClassInfo(cls)->GetName() << " event weights by factor: " << renormFactor.at(cls) << Endl;
01420       std::for_each( tmpEventVector[Types::kTraining].at(cls).begin(), 
01421                      tmpEventVector[Types::kTraining].at(cls).end(),
01422                      std::bind2nd(std::mem_fun(&TMVA::Event::ScaleWeight),renormFactor.at(cls)) );
01423       std::for_each( tmpEventVector[Types::kTesting].at(cls).begin(), 
01424                      tmpEventVector[Types::kTesting].at(cls).end(),
01425                      std::bind2nd(std::mem_fun(&TMVA::Event::ScaleWeight),renormFactor.at(cls)) );
01426    }
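// ------------------------------------------------------------------
// Editorial sketch (not part of the original source): the for_each/bind2nd
// idiom used above, on a self-contained type. 'Item' and 'ScaleAll' are
// illustrative; <algorithm>, <functional> and <vector> are already included
// in this file.
struct Item {
   Double_t fWeight;
   void Scale( Double_t f ) { fWeight *= f; }  // plays the role of Event::ScaleWeight
};
void ScaleAll( std::vector<Item*>& items, Double_t factor )
{
   // mem_fun adapts the member function; bind2nd fixes 'factor' as its argument
   std::for_each( items.begin(), items.end(),
                  std::bind2nd( std::mem_fun( &Item::Scale ), factor ) );
}
// ------------------------------------------------------------------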
01427 
01428 
01431    // ---------------------------------
01432    // for information purposes
01433    dsi.SetNormalization( normMode );
01434 
01435    // ============================
01436    // print out the result
01437    // (same code as above --> could be factored out into a helper)
01438    //
01439 
01440    Log() << kINFO << "Number of training and testing events after rescaling:" << Endl;
01441    Log() << kINFO << "------------------------------------------------------" << Endl;
01442    trainingSumWeights = 0;
01443    testingSumWeights  = 0;
01444    for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
01445 
01446       trainingSumWeightsPerClass.at(cls) = (std::accumulate( tmpEventVector[Types::kTraining].at(cls).begin(),  // accumulate --> start at begin
01447                                                              tmpEventVector[Types::kTraining].at(cls).end(),    //    until end()
01448                                                              Double_t(0),                                       // values are of type double
01449                                                              compose_binary( std::plus<Double_t>(),             // define addition for doubles
01450                                                                              null<Double_t>(),                  // take the argument, don't do anything and return it
01451                                                                              std::mem_fun(&TMVA::Event::GetOriginalWeight) ) )); // take the value from GetOriginalWeight
01452 
01453       testingSumWeightsPerClass.at(cls)  = std::accumulate( tmpEventVector[Types::kTesting].at(cls).begin(),
01454                                                             tmpEventVector[Types::kTesting].at(cls).end(),
01455                                                             Double_t(0),
01456                                                             compose_binary( std::plus<Double_t>(),
01457                                                                             null<Double_t>(),
01458                                                                             std::mem_fun(&TMVA::Event::GetOriginalWeight) ) );
01459 
01460 
01461       trainingSumWeights += trainingSumWeightsPerClass.at(cls);
01462       testingSumWeights  += testingSumWeightsPerClass.at(cls);
01463 
01464       // output statistics
01465       Log() << kINFO << setiosflags(ios::left) << std::setw(maxL) 
01466             << dsi.GetClassInfo(cls)->GetName() << " -- " 
01467             << "training entries            : " << trainingSizePerClass.at(cls) 
01468             <<  " (" << "sum of weights: " << trainingSumWeightsPerClass.at(cls) << ")" << Endl;
01469       Log() << kINFO << setiosflags(ios::left) << std::setw(maxL) 
01470             << dsi.GetClassInfo(cls)->GetName() << " -- " 
01471             << "testing entries             : " << testingSizePerClass.at(cls) 
01472             <<  " (" << "sum of weights: " << testingSumWeightsPerClass.at(cls) << ")" << Endl;
01473       Log() << kINFO << setiosflags(ios::left) << std::setw(maxL) 
01474             << dsi.GetClassInfo(cls)->GetName() << " -- " 
01475             << "training and testing entries: " 
01476             << (trainingSizePerClass.at(cls)+testingSizePerClass.at(cls)) 
01477             << " (" << "sum of weights: " 
01478             << (trainingSumWeightsPerClass.at(cls)+testingSumWeightsPerClass.at(cls)) << ")" << Endl;
01479    }
01480 
01481 }
01482 
01483 
01484 