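// TMVA::DataSetFactory
//
// Builds the TMVA::DataSet: reads the events from the registered input
// trees, evaluates the variable/target/spectator/cut/weight expressions,
// splits the events into training and testing samples, renormalises the
// event weights, and computes basic statistics (ranges, correlations).
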
#include <assert.h>
#include <cfloat>   // FLT_MAX / FLT_MIN, used in CalcMinMax

#include <map>
#include <vector>
#include <iomanip>
#include <iostream>

#include <algorithm>
#include <functional>
#include <numeric>

#include "TMVA/DataSetFactory.h"

#include "TEventList.h"
#include "TFile.h"
#include "TH1.h"
#include "TH2.h"
#include "TProfile.h"
#include "TRandom3.h"
#include "TMatrixF.h"
#include "TVectorF.h"
#include "TMath.h"
#include "TROOT.h"

#ifndef ROOT_TMVA_MsgLogger
#include "TMVA/MsgLogger.h"
#endif
#ifndef ROOT_TMVA_Configurable
#include "TMVA/Configurable.h"
#endif
#ifndef ROOT_TMVA_VariableIdentityTransform
#include "TMVA/VariableIdentityTransform.h"
#endif
#ifndef ROOT_TMVA_VariableDecorrTransform
#include "TMVA/VariableDecorrTransform.h"
#endif
#ifndef ROOT_TMVA_VariablePCATransform
#include "TMVA/VariablePCATransform.h"
#endif
#ifndef ROOT_TMVA_DataSet
#include "TMVA/DataSet.h"
#endif
#ifndef ROOT_TMVA_DataSetInfo
#include "TMVA/DataSetInfo.h"
#endif
#ifndef ROOT_TMVA_DataInputHandler
#include "TMVA/DataInputHandler.h"
#endif
#ifndef ROOT_TMVA_Event
#include "TMVA/Event.h"
#endif

using namespace std;

TMVA::DataSetFactory* TMVA::DataSetFactory::fgInstance = 0;

namespace TMVA {

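   //_______________________________________________________________________
   // greatest common divisor of a and b, computed with Euclid's algorithm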
   Int_t LargestCommonDivider(Int_t a, Int_t b)
   {
      if (a<b) { Int_t tmp = a; a = b; b = tmp; }
      if (b==0) return a;
      Int_t fullFits = a/b;
      return LargestCommonDivider(b, a - b*fullFits);
   }
}

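//_______________________________________________________________________
// constructor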
TMVA::DataSetFactory::DataSetFactory() :
   fVerbose(kFALSE),
   fVerboseLevel(TString("Info")),
   fCurrentTree(0),
   fCurrentEvtIdx(0),
   fInputFormulas(0),
   fLogger( new MsgLogger("DataSetFactory", kINFO) )
{
}

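//_______________________________________________________________________
// destructor: deletes all booked TTreeFormula objects and the logger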
TMVA::DataSetFactory::~DataSetFactory()
{
   std::vector<TTreeFormula*>::const_iterator formIt;

   for (formIt = fInputFormulas.begin()    ; formIt != fInputFormulas.end()    ; formIt++) if (*formIt) delete *formIt;
   for (formIt = fTargetFormulas.begin()   ; formIt != fTargetFormulas.end()   ; formIt++) if (*formIt) delete *formIt;
   for (formIt = fCutFormulas.begin()      ; formIt != fCutFormulas.end()      ; formIt++) if (*formIt) delete *formIt;
   for (formIt = fWeightFormula.begin()    ; formIt != fWeightFormula.end()    ; formIt++) if (*formIt) delete *formIt;
   for (formIt = fSpectatorFormulas.begin(); formIt != fSpectatorFormulas.end(); formIt++) if (*formIt) delete *formIt;

   delete fLogger;
}

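//_______________________________________________________________________
// general dataset creation: builds the initial event collections from the
// data input, then computes the variable ranges and the per-class
// correlation matrices.
//
// A minimal usage sketch (illustrative only: dataset creation is normally
// driven by TMVA::Factory; the Instance() accessor and the fully
// configured dsi/dih objects are assumptions, not taken from this file):
//
//    TMVA::DataSetInfo      dsi;   // variables, classes, cuts, weights
//    TMVA::DataInputHandler dih;   // input trees, registered per class
//    TMVA::DataSet* ds = TMVA::DataSetFactory::Instance().CreateDataSet( dsi, dih );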
TMVA::DataSet* TMVA::DataSetFactory::CreateDataSet( TMVA::DataSetInfo& dsi, TMVA::DataInputHandler& dataInput )
{
   // build the first dataset from the data input
   DataSet* ds = BuildInitialDataSet( dsi, dataInput );

   if (ds->GetNEvents() > 1) {
      CalcMinMax(ds,dsi);

      // from the final dataset build the correlation matrix of each class
      for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
         const TString className = dsi.GetClassInfo(cl)->GetName();
         dsi.SetCorrelationMatrix( className, CalcCorrelationMatrix( ds, cl ) );
         dsi.PrintCorrelationMatrix( className );
      }
      Log() << kINFO << " " << Endl;
   }
   return ds;
}

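//_______________________________________________________________________
// builds a dataset consisting of a single event whose variable and
// spectator values are read through external links (pointers set by the
// user); used when events are provided dynamically rather than from trees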
TMVA::DataSet* TMVA::DataSetFactory::BuildDynamicDataSet( TMVA::DataSetInfo& dsi )
{
   Log() << kDEBUG << "Build DataSet consisting of one Event with dynamically changing variables" << Endl;
   DataSet* ds = new DataSet(dsi);

   // if no class is given, set a default "data" class
   if (dsi.GetNClasses()==0) {
      dsi.AddClass( "data" );
      dsi.GetClassInfo( "data" )->SetNumber(0);
   }

   // collect the external links to the variable and spectator values
   std::vector<Float_t*>* evdyn = new std::vector<Float_t*>(0);

   std::vector<VariableInfo>& varinfos = dsi.GetVariableInfos();
   std::vector<VariableInfo>::iterator it = varinfos.begin();
   for (; it!=varinfos.end(); it++) evdyn->push_back( (Float_t*)(*it).GetExternalLink() );

   std::vector<VariableInfo>& spectatorinfos = dsi.GetSpectatorInfos();
   it = spectatorinfos.begin();
   for (; it!=spectatorinfos.end(); it++) evdyn->push_back( (Float_t*)(*it).GetExternalLink() );

   TMVA::Event* ev = new Event((const std::vector<Float_t*>*&)evdyn, varinfos.size());
   std::vector<Event*>* newEventVector = new std::vector<Event*>;
   newEventVector->push_back(ev);

   ds->SetEventCollection(newEventVector, Types::kTraining);
   ds->SetCurrentType( Types::kTraining );
   ds->SetCurrentEvent( 0 );

   return ds;
}

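//_______________________________________________________________________
// builds the initial dataset: registers the classes found in the input,
// parses the splitting options, reads all events from the input trees and
// distributes them into training and testing samples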
TMVA::DataSet* TMVA::DataSetFactory::BuildInitialDataSet( DataSetInfo& dsi, DataInputHandler& dataInput )
{
   // if no entries are available, then create a DataSet with one Event
   // which uses dynamic variables (pointers to variables)
   if (dataInput.GetEntries()==0) return BuildDynamicDataSet( dsi );

   // register the classes in the DataSetInfo object; the information
   // comes from the trees in the DataInputHandler object
   std::vector< TString >* classList = dataInput.GetClassList();
   for (std::vector<TString>::iterator it = classList->begin(); it != classList->end(); it++) {
      dsi.AddClass( (*it) );
   }
   delete classList;

   TString normMode;
   TString splitMode;
   TString mixMode;
   UInt_t  splitSeed;

   // event vectors, organised by tree type (training/testing/undefined)
   // and by class
   TMVA::EventVectorOfClassesOfTreeType tmpEventVector;
   TMVA::NumberPerClassOfTreeType nTrainTestEvents;

   InitOptions     ( dsi, nTrainTestEvents, normMode, splitSeed, splitMode, mixMode );
   BuildEventVector( dsi, dataInput, tmpEventVector );

   DataSet* ds = MixEvents( dsi, tmpEventVector, nTrainTestEvents, splitMode, mixMode, normMode, splitSeed );

   const Bool_t showCollectedOutput = kFALSE;
   if (showCollectedOutput) {
      Int_t maxL = dsi.GetClassNameMaxLength();
      Log() << kINFO << "Collected:" << Endl;
      for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
         Log() << kINFO << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << " training entries: " << ds->GetNClassEvents( 0, cl ) << Endl;
         Log() << kINFO << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << " testing entries: " << ds->GetNClassEvents( 1, cl ) << Endl;
      }
      Log() << kINFO << " " << Endl;
   }

   return ds;
}

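//_______________________________________________________________________
// checks a TTreeFormula for problems: aborts if the expression could not
// be compiled, warns if it provides no data for the current event (e.g.
// an array index beyond the array length); returns kFALSE in that case
// and flags (via hasDollar) expressions containing a '$'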
Bool_t TMVA::DataSetFactory::CheckTTreeFormula( TTreeFormula* ttf, const TString& expression, Bool_t& hasDollar )
{
   Bool_t worked = kTRUE;

   if( ttf->GetNdim() <= 0 )
      Log() << kFATAL << "Expression " << expression.Data() << " could not be resolved to a valid formula. " << Endl;

   if( ttf->GetNdata() == 0 ){
      Log() << kWARNING << "Expression: " << expression.Data()
            << " does not provide data for this event. "
            << "This event is not taken into account. --> please check if you use as a variable "
            << "an entry of an array which is not filled for some events "
            << "(e.g. arr[4] when arr has only 3 elements)." << Endl;
      Log() << kWARNING << "If you want to take the event into account you can do something like: "
            << "\"Alt$(arr[4],0)\" where in cases where arr doesn't have a 4th element, "
            << " 0 is taken as an alternative." << Endl;
      worked = kFALSE;
   }
   if( expression.Contains("$") ) hasDollar = kTRUE;
   return worked;
}

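//_______________________________________________________________________
// while data gets copied into the local training and testing event
// collections, the tree change has to be followed: (re)creates the
// TTreeFormula objects for variables, targets, spectators, cuts and
// weights on the current tree, and enables only the branches that are
// actually needed (unless an expression contains '$', in which case all
// branches stay enabled)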
void TMVA::DataSetFactory::ChangeToNewTree( TreeInfo& tinfo, const DataSetInfo & dsi )
{
   // if the tree is part of a chain, act on the currently loaded tree
   TTree *tr = tinfo.GetTree()->GetTree();

   tr->SetBranchStatus("*",1);

   Bool_t hasDollar = kFALSE;

   // 1) the input variable formulas
   Log() << kDEBUG << "transform input variables" << Endl;
   std::vector<TTreeFormula*>::const_iterator formIt, formItEnd;
   for (formIt = fInputFormulas.begin(), formItEnd=fInputFormulas.end(); formIt!=formItEnd; formIt++) if (*formIt) delete *formIt;
   fInputFormulas.clear();
   TTreeFormula* ttf = 0;

   for (UInt_t i=0; i<dsi.GetNVariables(); i++) {
      ttf = new TTreeFormula( Form( "Formula%s", dsi.GetVariableInfo(i).GetInternalName().Data() ),
                              dsi.GetVariableInfo(i).GetExpression().Data(), tr );
      CheckTTreeFormula( ttf, dsi.GetVariableInfo(i).GetExpression(), hasDollar );
      fInputFormulas.push_back( ttf );
   }

   // 2) the regression target formulas
   Log() << kDEBUG << "transform regression targets" << Endl;
   for (formIt = fTargetFormulas.begin(), formItEnd = fTargetFormulas.end(); formIt!=formItEnd; formIt++) if (*formIt) delete *formIt;
   fTargetFormulas.clear();
   for (UInt_t i=0; i<dsi.GetNTargets(); i++) {
      ttf = new TTreeFormula( Form( "Formula%s", dsi.GetTargetInfo(i).GetInternalName().Data() ),
                              dsi.GetTargetInfo(i).GetExpression().Data(), tr );
      CheckTTreeFormula( ttf, dsi.GetTargetInfo(i).GetExpression(), hasDollar );
      fTargetFormulas.push_back( ttf );
   }

   // 3) the spectator formulas
   Log() << kDEBUG << "transform spectator variables" << Endl;
   for (formIt = fSpectatorFormulas.begin(), formItEnd = fSpectatorFormulas.end(); formIt!=formItEnd; formIt++) if (*formIt) delete *formIt;
   fSpectatorFormulas.clear();
   for (UInt_t i=0; i<dsi.GetNSpectators(); i++) {
      ttf = new TTreeFormula( Form( "Formula%s", dsi.GetSpectatorInfo(i).GetInternalName().Data() ),
                              dsi.GetSpectatorInfo(i).GetExpression().Data(), tr );
      CheckTTreeFormula( ttf, dsi.GetSpectatorInfo(i).GetExpression(), hasDollar );
      fSpectatorFormulas.push_back( ttf );
   }

   // 4) the cut formulas (one per class; 0 if the class has no cut)
   Log() << kDEBUG << "transform cuts" << Endl;
   for (formIt = fCutFormulas.begin(), formItEnd = fCutFormulas.end(); formIt!=formItEnd; formIt++) if (*formIt) delete *formIt;
   fCutFormulas.clear();
   for (UInt_t clIdx=0; clIdx<dsi.GetNClasses(); clIdx++) {
      const TCut& tmpCut = dsi.GetClassInfo(clIdx)->GetCut();
      const TString tmpCutExp(tmpCut.GetTitle());
      ttf = 0;
      if (tmpCutExp!="") {
         ttf = new TTreeFormula( Form("CutClass%i",clIdx), tmpCutExp, tr );
         Bool_t worked = CheckTTreeFormula( ttf, tmpCutExp, hasDollar );
         if( !worked ){
            Log() << kWARNING << "Please check class \"" << dsi.GetClassInfo(clIdx)->GetName()
                  << "\" cut \"" << dsi.GetClassInfo(clIdx)->GetCut() << "\"" << Endl;
         }
      }
      fCutFormulas.push_back( ttf );
   }

   // 5) the weight formulas (only for the class this tree belongs to)
   Log() << kDEBUG << "transform weights" << Endl;
   for (formIt = fWeightFormula.begin(), formItEnd = fWeightFormula.end(); formIt!=formItEnd; formIt++) if (*formIt) delete *formIt;
   fWeightFormula.clear();
   for (UInt_t clIdx=0; clIdx<dsi.GetNClasses(); clIdx++) {
      const TString tmpWeight = dsi.GetClassInfo(clIdx)->GetWeight();

      if (dsi.GetClassInfo(clIdx)->GetName() != tinfo.GetClassName() ) {
         fWeightFormula.push_back( 0 );
         continue;
      }

      ttf = 0;
      if (tmpWeight!="") {
         ttf = new TTreeFormula( "FormulaWeight", tmpWeight, tr );
         Bool_t worked = CheckTTreeFormula( ttf, tmpWeight, hasDollar );
         if( !worked ){
            Log() << kWARNING << "Please check class \"" << dsi.GetClassInfo(clIdx)->GetName()
                  << "\" weight \"" << dsi.GetClassInfo(clIdx)->GetWeight() << "\"" << Endl;
         }
      }
      fWeightFormula.push_back( ttf );
   }

   // now enable only the branches that are needed by any of the formulas
   Log() << kDEBUG << "enable branches" << Endl;
   if (!hasDollar) {
      tr->SetBranchStatus("*",0);
      Log() << kDEBUG << "enable branches: input variables" << Endl;
      for (formIt = fInputFormulas.begin(); formIt!=fInputFormulas.end(); formIt++) {
         ttf = *formIt;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++) {
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
         }
      }

      Log() << kDEBUG << "enable branches: targets" << Endl;
      for (formIt = fTargetFormulas.begin(); formIt!=fTargetFormulas.end(); formIt++) {
         ttf = *formIt;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }

      Log() << kDEBUG << "enable branches: spectators" << Endl;
      for (formIt = fSpectatorFormulas.begin(); formIt!=fSpectatorFormulas.end(); formIt++) {
         ttf = *formIt;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }

      Log() << kDEBUG << "enable branches: cuts" << Endl;
      for (formIt = fCutFormulas.begin(); formIt!=fCutFormulas.end(); formIt++) {
         ttf = *formIt;
         if (!ttf) continue;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }

      Log() << kDEBUG << "enable branches: weights" << Endl;
      for (formIt = fWeightFormula.begin(); formIt!=fWeightFormula.end(); formIt++) {
         ttf = *formIt;
         if (!ttf) continue;
         for (Int_t bi = 0; bi<ttf->GetNcodes(); bi++)
            tr->SetBranchStatus( ttf->GetLeaf(bi)->GetBranch()->GetName(), 1 );
      }
   }
   Log() << kDEBUG << "tree initialized" << Endl;
   return;
}

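//_______________________________________________________________________
// computes the minimum and maximum of each variable, target and spectator
// over all events in the dataset and stores them in the DataSetInfo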
void TMVA::DataSetFactory::CalcMinMax( DataSet* ds, TMVA::DataSetInfo& dsi )
{
   const UInt_t nvar  = ds->GetNVariables();
   const UInt_t ntgts = ds->GetNTargets();
   const UInt_t nvis  = ds->GetNSpectators();

   Float_t *min   = new Float_t[nvar];
   Float_t *max   = new Float_t[nvar];
   Float_t *tgmin = new Float_t[ntgts];
   Float_t *tgmax = new Float_t[ntgts];
   Float_t *vmin  = new Float_t[nvis];
   Float_t *vmax  = new Float_t[nvis];

   for (UInt_t ivar=0; ivar<nvar ; ivar++) { min[ivar]   = FLT_MAX; max[ivar]   = -FLT_MAX; }
   for (UInt_t ivar=0; ivar<ntgts; ivar++) { tgmin[ivar] = FLT_MAX; tgmax[ivar] = -FLT_MAX; }
   for (UInt_t ivar=0; ivar<nvis ; ivar++) { vmin[ivar]  = FLT_MAX; vmax[ivar]  = -FLT_MAX; }

   // perform the event loop and collect the extrema
   for (Int_t i=0; i<ds->GetNEvents(); i++) {
      Event * ev = ds->GetEvent(i);
      for (UInt_t ivar=0; ivar<nvar; ivar++) {
         Double_t v = ev->GetValue(ivar);
         if (v<min[ivar]) min[ivar] = v;
         if (v>max[ivar]) max[ivar] = v;
      }
      for (UInt_t itgt=0; itgt<ntgts; itgt++) {
         Double_t v = ev->GetTarget(itgt);
         if (v<tgmin[itgt]) tgmin[itgt] = v;
         if (v>tgmax[itgt]) tgmax[itgt] = v;
      }
      for (UInt_t ivis=0; ivis<nvis; ivis++) {
         Double_t v = ev->GetSpectator(ivis);
         if (v<vmin[ivis]) vmin[ivis] = v;
         if (v>vmax[ivis]) vmax[ivis] = v;
      }
   }

   // store the ranges; constant variables and targets are useless for an
   // MVA, so abort in that case
   for (UInt_t ivar=0; ivar<nvar; ivar++) {
      dsi.GetVariableInfo(ivar).SetMin(min[ivar]);
      dsi.GetVariableInfo(ivar).SetMax(max[ivar]);
      if( TMath::Abs(max[ivar]-min[ivar]) <= FLT_MIN )
         Log() << kFATAL << "Variable " << dsi.GetVariableInfo(ivar).GetExpression().Data() << " is constant. Please remove the variable." << Endl;
   }
   for (UInt_t ivar=0; ivar<ntgts; ivar++) {
      dsi.GetTargetInfo(ivar).SetMin(tgmin[ivar]);
      dsi.GetTargetInfo(ivar).SetMax(tgmax[ivar]);
      if( TMath::Abs(tgmax[ivar]-tgmin[ivar]) <= FLT_MIN )
         Log() << kFATAL << "Target " << dsi.GetTargetInfo(ivar).GetExpression().Data() << " is constant. Please remove the variable." << Endl;
   }
   for (UInt_t ivar=0; ivar<nvis; ivar++) {
      dsi.GetSpectatorInfo(ivar).SetMin(vmin[ivar]);
      dsi.GetSpectatorInfo(ivar).SetMax(vmax[ivar]);
   }
   delete [] min;
   delete [] max;
   delete [] tgmin;
   delete [] tgmax;
   delete [] vmin;
   delete [] vmax;
}

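//_______________________________________________________________________
// computes the correlation matrix of the input variables for one class
// from the covariance matrix:  rho(i,j) = cov(i,j) / sqrt( cov(i,i) * cov(j,j) )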
TMatrixD* TMVA::DataSetFactory::CalcCorrelationMatrix( DataSet* ds, const UInt_t classNumber )
{
   // first compute the covariance matrix, then normalise it in place
   TMatrixD* mat = CalcCovarianceMatrix( ds, classNumber );

   UInt_t nvar = ds->GetNVariables(), ivar, jvar;

   for (ivar=0; ivar<nvar; ivar++) {
      for (jvar=0; jvar<nvar; jvar++) {
         if (ivar != jvar) {
            Double_t d = (*mat)(ivar, ivar)*(*mat)(jvar, jvar);
            if (d > 0) (*mat)(ivar, jvar) /= sqrt(d);
            else {
               Log() << kWARNING << "<CalcCorrelationMatrix> Zero variances for variables "
                     << "(" << ivar << ", " << jvar << ") = " << d
                     << Endl;
               (*mat)(ivar, jvar) = 0;
            }
         }
      }
   }

   for (ivar=0; ivar<nvar; ivar++) (*mat)(ivar, ivar) = 1.0;

   return mat;
}

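//_______________________________________________________________________
// computes the weighted covariance matrix of the input variables for one
// class:  cov(i,j) = <x_i x_j> - <x_i><x_j>, with all averages taken with
// the event weights (ic = sum of weights)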
TMatrixD* TMVA::DataSetFactory::CalcCovarianceMatrix( DataSet * ds, const UInt_t classNumber )
{
   UInt_t nvar = ds->GetNVariables();
   UInt_t ivar = 0, jvar = 0;

   TMatrixD* mat = new TMatrixD( nvar, nvar );

   // init the accumulators for the first and second moments
   TVectorD vec(nvar);
   TMatrixD mat2(nvar, nvar);
   for (ivar=0; ivar<nvar; ivar++) {
      vec(ivar) = 0;
      for (jvar=0; jvar<nvar; jvar++) mat2(ivar, jvar) = 0;
   }

   // perform the event loop
   Double_t ic = 0;
   for (Int_t i=0; i<ds->GetNEvents(); i++) {

      // consider only events of the requested class
      Event * ev = ds->GetEvent(i);
      if (ev->GetClass() != classNumber ) continue;

      Double_t weight = ev->GetWeight();
      ic += weight; // sum of weights of the used events

      for (ivar=0; ivar<nvar; ivar++) {

         Double_t xi = ev->GetValue(ivar);
         vec(ivar) += xi*weight;
         mat2(ivar, ivar) += (xi*xi*weight);

         for (jvar=ivar+1; jvar<nvar; jvar++) {
            Double_t xj = ev->GetValue(jvar);
            mat2(ivar, jvar) += (xi*xj*weight);
         }
      }
   }

   // the matrix is symmetric: fill the lower triangle
   for (ivar=0; ivar<nvar; ivar++)
      for (jvar=ivar+1; jvar<nvar; jvar++)
         mat2(jvar, ivar) = mat2(ivar, jvar);

   // variance-covariance
   for (ivar=0; ivar<nvar; ivar++) {
      for (jvar=0; jvar<nvar; jvar++) {
         (*mat)(ivar, jvar) = mat2(ivar, jvar)/ic - vec(ivar)*vec(jvar)/(ic*ic);
      }
   }

   return mat;
}

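//_______________________________________________________________________
// parses the dataset split options given in the
// "PrepareForTrainingAndTesting" call: split mode, mix mode, split seed,
// weight normalisation mode and the requested numbers of training and
// test events per class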
void TMVA::DataSetFactory::InitOptions( TMVA::DataSetInfo& dsi,
                                        TMVA::NumberPerClassOfTreeType& nTrainTestEvents,
                                        TString& normMode, UInt_t& splitSeed,
                                        TString& splitMode,
                                        TString& mixMode )
{
   Configurable splitSpecs( dsi.GetSplitOptions() );
   splitSpecs.SetConfigName("DataSetFactory");
   splitSpecs.SetConfigDescription( "Configuration options given in the \"PrepareForTrainingAndTesting\" call; these options define the creation of the data sets used for training and expert validation by TMVA" );

   splitMode = "Random";    // the splitting mode
   splitSpecs.DeclareOptionRef( splitMode, "SplitMode",
                                "Method of picking training and testing events (default: random)" );
   splitSpecs.AddPreDefVal(TString("Random"));
   splitSpecs.AddPreDefVal(TString("Alternate"));
   splitSpecs.AddPreDefVal(TString("Block"));

   mixMode = "SameAsSplitMode";    // the mixing mode
   splitSpecs.DeclareOptionRef( mixMode, "MixMode",
                                "Method of mixing events of different classes into one dataset (default: SameAsSplitMode)" );
   splitSpecs.AddPreDefVal(TString("SameAsSplitMode"));
   splitSpecs.AddPreDefVal(TString("Random"));
   splitSpecs.AddPreDefVal(TString("Alternate"));
   splitSpecs.AddPreDefVal(TString("Block"));

   splitSeed = 100;
   splitSpecs.DeclareOptionRef( splitSeed, "SplitSeed",
                                "Seed for random event shuffling" );

   normMode = "NumEvents";  // the weight normalisation mode
   splitSpecs.DeclareOptionRef( normMode, "NormMode",
                                "Overall renormalisation of event-by-event weights (NumEvents: average weight of 1 per event, independently for signal and background; EqualNumEvents: average weight of 1 per event for signal, and sum of weights for background equal to sum of weights for signal)" );
   splitSpecs.AddPreDefVal(TString("None"));
   splitSpecs.AddPreDefVal(TString("NumEvents"));
   splitSpecs.AddPreDefVal(TString("EqualNumEvents"));

   // the requested number of training and test events per class
   // (0 means: use all available events)
   nTrainTestEvents.insert( TMVA::NumberPerClassOfTreeType::value_type( Types::kTraining, TMVA::NumberPerClass( dsi.GetNClasses() ) ) );
   nTrainTestEvents.insert( TMVA::NumberPerClassOfTreeType::value_type( Types::kTesting,  TMVA::NumberPerClass( dsi.GetNClasses() ) ) );

   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      nTrainTestEvents[Types::kTraining].at(cl) = 0;
      nTrainTestEvents[Types::kTesting].at(cl)  = 0;

      TString clName = dsi.GetClassInfo(cl)->GetName();
      TString titleTrain = TString().Format("Number of training events of class %s (default: 0 = all)",clName.Data());
      TString titleTest  = TString().Format("Number of test events of class %s (default: 0 = all)",clName.Data());

      splitSpecs.DeclareOptionRef( nTrainTestEvents[Types::kTraining].at(cl), TString("nTrain_")+clName, titleTrain );
      splitSpecs.DeclareOptionRef( nTrainTestEvents[Types::kTesting].at(cl) , TString("nTest_")+clName , titleTest  );
   }

   splitSpecs.DeclareOptionRef( fVerbose, "V", "Verbosity (default: false)" );

   splitSpecs.DeclareOptionRef( fVerboseLevel=TString("Info"), "VerboseLevel", "VerboseLevel (Debug/Verbose/Info)" );
   splitSpecs.AddPreDefVal(TString("Debug"));
   splitSpecs.AddPreDefVal(TString("Verbose"));
   splitSpecs.AddPreDefVal(TString("Info"));

   splitSpecs.ParseOptions();
   splitSpecs.CheckForUnusedOptions();

   // set the logging verbosity
   if (Verbose()) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Debug")   == 0) fLogger->SetMinType( kDEBUG );
   if (fVerboseLevel.CompareTo("Verbose") == 0) fLogger->SetMinType( kVERBOSE );
   if (fVerboseLevel.CompareTo("Info")    == 0) fLogger->SetMinType( kINFO );

   // put all modes to upper case
   splitMode.ToUpper(); mixMode.ToUpper(); normMode.ToUpper();

   Log() << kINFO << "Splitmode is: \"" << splitMode << "\" the mixmode is: \"" << mixMode << "\"" << Endl;
   if (mixMode=="SAMEASSPLITMODE") mixMode = splitMode;
   else if (mixMode!=splitMode)
      Log() << kINFO << "DataSet splitmode=" << splitMode
            << " differs from mixmode=" << mixMode << Endl;
}

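//_______________________________________________________________________
// reads all events from the input trees, evaluates the variable, target,
// spectator, cut and weight expressions for each of them, and fills the
// surviving events into the per-class event vectors of the corresponding
// tree type (training, testing, or undefined)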
void TMVA::DataSetFactory::BuildEventVector( TMVA::DataSetInfo& dsi,
                                             TMVA::DataInputHandler& dataInput,
                                             TMVA::EventVectorOfClassesOfTreeType& tmpEventVector )
{
   // create the event vectors for each tree type; events whose tree type
   // is not yet known are stored under Types::kMaxTreeType
   tmpEventVector.insert( std::make_pair(Types::kTraining   ,TMVA::EventVectorOfClasses(dsi.GetNClasses() ) ) );
   tmpEventVector.insert( std::make_pair(Types::kTesting    ,TMVA::EventVectorOfClasses(dsi.GetNClasses() ) ) );
   tmpEventVector.insert( std::make_pair(Types::kMaxTreeType,TMVA::EventVectorOfClasses(dsi.GetNClasses() ) ) );

   const UInt_t nvars = dsi.GetNVariables();
   const UInt_t ntgts = dsi.GetNTargets();
   const UInt_t nvis  = dsi.GetNSpectators();

   // per-class bookkeeping: event counts and weight sums before/after cuts
   std::vector< Int_t >    nInitialEvents( dsi.GetNClasses() );
   std::vector< Int_t >    nEvBeforeCut(   dsi.GetNClasses() );
   std::vector< Int_t >    nEvAfterCut(    dsi.GetNClasses() );
   std::vector< Float_t >  nWeEvBeforeCut( dsi.GetNClasses() );
   std::vector< Float_t >  nWeEvAfterCut(  dsi.GetNClasses() );
   std::vector< Double_t > nNegWeights(    dsi.GetNClasses() );
   std::vector< Float_t* > varAvLength(    dsi.GetNClasses() );

   Bool_t haveArrayVariable = kFALSE;
   Bool_t *varIsArray = new Bool_t[nvars];
   for (UInt_t ivar=0; ivar<nvars; ivar++) varIsArray[ivar] = kFALSE;

   for (size_t i=0; i<varAvLength.size(); i++) {
      varAvLength[i] = new Float_t[nvars];
      for (UInt_t ivar=0; ivar<nvars; ivar++) {
         varAvLength[i][ivar] = 0;
      }
   }

   // loop over all input trees, class by class
   for (UInt_t cl=0; cl<dsi.GetNClasses(); cl++) {

      Log() << kINFO << "Create training and testing trees -- looping over class \""
            << dsi.GetClassInfo(cl)->GetName() << "\" ..." << Endl;

      // info output for the weight expression of this class
      const TString tmpWeight = dsi.GetClassInfo(cl)->GetWeight();
      if (tmpWeight!="") {
         Log() << kINFO << "Weight expression for class \"" << dsi.GetClassInfo(cl)->GetName() << "\": \""
               << tmpWeight << "\"" << Endl;
      }
      else {
         Log() << kINFO << "No weight expression defined for class \"" << dsi.GetClassInfo(cl)->GetName()
               << "\"" << Endl;
      }

      // used for chains only: the name of the currently opened file
      TString currentFileName("");

      std::vector<TreeInfo>::const_iterator treeIt(dataInput.begin(dsi.GetClassInfo(cl)->GetName()));
      for (; treeIt!=dataInput.end(dsi.GetClassInfo(cl)->GetName()); treeIt++) {

         // local buffers for the event values
         std::vector<Float_t> vars(nvars);
         std::vector<Float_t> tgts(ntgts);
         std::vector<Float_t> vis(nvis);
         TreeInfo currentInfo = *treeIt;

         Bool_t isChain = (TString("TChain") == currentInfo.GetTree()->ClassName());
         currentInfo.GetTree()->LoadTree(0);
         ChangeToNewTree( currentInfo, dsi );

         nInitialEvents.at(cl) += currentInfo.GetTree()->GetEntries();

         // loop over the events in the tree
         for (Long64_t evtIdx = 0; evtIdx < currentInfo.GetTree()->GetEntries(); evtIdx++) {
            currentInfo.GetTree()->LoadTree(evtIdx);

            // when a chain moves to a new file, the formulas must be re-booked
            if (isChain) {
               if (currentInfo.GetTree()->GetTree()->GetDirectory()->GetFile()->GetName() != currentFileName) {
                  currentFileName = currentInfo.GetTree()->GetTree()->GetDirectory()->GetFile()->GetName();
                  ChangeToNewTree( currentInfo, dsi );
               }
            }
            currentInfo.GetTree()->GetEntry(evtIdx);
            Int_t sizeOfArrays = 1;
            Int_t prevArrExpr  = 0;

            // if some variables are arrays, all array-type expressions of
            // this event must have the same length
            for (UInt_t ivar=0; ivar<nvars; ivar++) {
               Int_t ndata = fInputFormulas[ivar]->GetNdata();
               varAvLength[cl][ivar] += ndata;
               if (ndata == 1) continue;
               haveArrayVariable = kTRUE;
               varIsArray[ivar] = kTRUE;
               if (sizeOfArrays == 1) {
                  sizeOfArrays = ndata;
                  prevArrExpr = ivar;
               }
               else if (sizeOfArrays!=ndata) {
                  Log() << kERROR << "ERROR while preparing training and testing trees:" << Endl;
                  Log() << "   multiple array-type expressions of different length were encountered" << Endl;
                  Log() << "   location of error: event " << evtIdx
                        << " in tree " << currentInfo.GetTree()->GetName()
                        << " of file " << currentInfo.GetTree()->GetCurrentFile()->GetName() << Endl;
                  Log() << "   expression " << fInputFormulas[ivar]->GetTitle() << " has "
                        << ndata << " entries, while" << Endl;
                  Log() << "   expression " << fInputFormulas[prevArrExpr]->GetTitle() << " has "
                        << fInputFormulas[prevArrExpr]->GetNdata() << " entries" << Endl;
                  Log() << kFATAL << "Need to abort" << Endl;
               }
            }

            // each element of an array-type expression is treated as a separate event
            for (Int_t idata = 0; idata<sizeOfArrays; idata++) {
               Bool_t containsNaN = kFALSE;

               TTreeFormula* formula = 0;

               // the cut expression of this class
               Float_t cutVal = 1;
               formula = fCutFormulas[cl];
               if (formula) {
                  Int_t ndata = formula->GetNdata();
                  cutVal = (ndata==1 ?
                            formula->EvalInstance(0) :
                            formula->EvalInstance(idata));
                  if (TMath::IsNaN(cutVal)) {
                     containsNaN = kTRUE;
                     Log() << kWARNING << "Cut expression resolves to NaN (not a number): "
                           << formula->GetTitle() << Endl;
                  }
               }

               // the input variables
               for (UInt_t ivar=0; ivar<nvars; ivar++) {
                  formula = fInputFormulas[ivar];
                  Int_t ndata = formula->GetNdata();
                  vars[ivar] = (ndata == 1 ?
                                formula->EvalInstance(0) :
                                formula->EvalInstance(idata));
                  if (TMath::IsNaN(vars[ivar])) {
                     containsNaN = kTRUE;
                     Log() << kWARNING << "Input expression resolves to NaN (not a number): "
                           << formula->GetTitle() << Endl;
                  }
               }

               // the regression targets
               for (UInt_t itrgt=0; itrgt<ntgts; itrgt++) {
                  formula = fTargetFormulas[itrgt];
                  Int_t ndata = formula->GetNdata();
                  tgts[itrgt] = (ndata == 1 ?
                                 formula->EvalInstance(0) :
                                 formula->EvalInstance(idata));
                  if (TMath::IsNaN(tgts[itrgt])) {
                     containsNaN = kTRUE;
                     Log() << kWARNING << "Target expression resolves to NaN (not a number): "
                           << formula->GetTitle() << Endl;
                  }
               }

               // the spectators
               for (UInt_t itVis=0; itVis<nvis; itVis++) {
                  formula = fSpectatorFormulas[itVis];
                  Int_t ndata = formula->GetNdata();
                  vis[itVis] = (ndata == 1 ?
                                formula->EvalInstance(0) :
                                formula->EvalInstance(idata));
                  if (TMath::IsNaN(vis[itVis])) {
                     containsNaN = kTRUE;
                     Log() << kWARNING << "Spectator expression resolves to NaN (not a number): "
                           << formula->GetTitle() << Endl;
                  }
               }

               // the event weight: tree weight times optional weight expression
               Float_t weight = currentInfo.GetWeight();
               formula = fWeightFormula[cl];
               if (formula!=0) {
                  Int_t ndata = formula->GetNdata();
                  weight *= (ndata == 1 ?
                             formula->EvalInstance() :
                             formula->EvalInstance(idata));
                  if (TMath::IsNaN(weight)) {
                     containsNaN = kTRUE;
                     Log() << kWARNING << "Weight expression resolves to NaN (not a number): "
                           << formula->GetTitle() << Endl;
                  }
               }

               // count the events before the cut (weighted and unweighted)
               nEvBeforeCut.at(cl)++;
               if (!TMath::IsNaN(weight))
                  nWeEvBeforeCut.at(cl) += weight;

               // apply the cut (a cut value below 0.5 means "rejected")
               if (cutVal<0.5) continue;

               // count negative event weights
               if (weight < 0) nNegWeights.at(cl)++;

               // events with NaN values are rejected
               if (containsNaN) {
                  Log() << kWARNING << "Event " << evtIdx;
                  if (sizeOfArrays>1) Log() << " (array instance " << idata << ")";
                  Log() << " rejected because it contains a NaN value" << Endl;
                  continue;
               }

               // count the events after the cut (weighted and unweighted)
               nEvAfterCut.at(cl)++;
               nWeEvAfterCut.at(cl) += weight;

               // the event passed all selections: store it
               tmpEventVector.find(currentInfo.GetTreeType())->second.at(cl).push_back(new Event(vars, tgts, vis, cl, weight));
            }
         }

         currentInfo.GetTree()->ResetBranchAddresses();
      }
   }

   // summary output
   Int_t maxL = dsi.GetClassNameMaxLength();

   Log() << kINFO << "Number of events in input trees (after possible flattening of arrays):" << Endl;
   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      Log() << kINFO << "    "
            << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
            << "      -- number of events       : "
            << std::setw(5) << nEvBeforeCut.at(cl)
            << "  / sum of weights: " << std::setw(5) << nWeEvBeforeCut.at(cl) << Endl;
   }

   for (UInt_t cl = 0; cl < dsi.GetNClasses(); cl++) {
      Log() << kINFO << "    " << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
            << " tree -- total number of entries: "
            << std::setw(5) << dataInput.GetEntries(dsi.GetClassInfo(cl)->GetName()) << Endl;
   }

   Log() << kINFO << "Preselection:" << Endl;
   if (dsi.HasCuts()) {
      for (UInt_t cl = 0; cl< dsi.GetNClasses(); cl++) {
         Log() << kINFO << "    " << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << " requirement: \"" << dsi.GetClassInfo(cl)->GetCut() << "\"" << Endl;
         Log() << kINFO << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << "      -- number of events passed: "
               << std::setw(5) << nEvAfterCut.at(cl)
               << "  / sum of weights: " << std::setw(5) << nWeEvAfterCut.at(cl) << Endl;
         Log() << kINFO << "    "
               << setiosflags(ios::left) << std::setw(maxL) << dsi.GetClassInfo(cl)->GetName()
               << "      -- efficiency             : "
               << std::setw(6) << nWeEvAfterCut.at(cl)/nWeEvBeforeCut.at(cl) << Endl;
      }
   }
   else Log() << kINFO << "    No preselection cuts applied on event classes" << Endl;

   delete[] varIsArray;
   for (size_t i=0; i<varAvLength.size(); i++)
      delete[] varAvLength[i];
}

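//_______________________________________________________________________
// distributes the events among training and testing sample according to
// the requested numbers of events per class and the split mode (Random /
// Alternate / Block), renormalises the event weights, and merges the
// per-class vectors into the final training and testing event collections
// according to the mix mode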
TMVA::DataSet* TMVA::DataSetFactory::MixEvents( DataSetInfo& dsi,
                                                TMVA::EventVectorOfClassesOfTreeType& tmpEventVector,
                                                TMVA::NumberPerClassOfTreeType& nTrainTestEvents,
                                                const TString& splitMode,
                                                const TString& mixMode,
                                                const TString& normMode,
                                                UInt_t splitSeed)
{
   // check whether the vectors of events whose tree type is not yet
   // defined are empty for all classes
   Bool_t emptyUndefined = kTRUE;
   for( Int_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
      emptyUndefined &= tmpEventVector[Types::kMaxTreeType].at(cls).empty();
   }

   TMVA::RandomGenerator rndm( splitSeed );

   // shuffle the undefined events before the split so that the assignment
   // does not depend on the order in the input trees
   if (splitMode.Contains( "RANDOM" ) && !emptyUndefined ) {
      Log() << kDEBUG << "randomly shuffling events which are not yet associated to testing or training" << Endl;
      for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
         std::random_shuffle( tmpEventVector[Types::kMaxTreeType].at(cls).begin(),
                              tmpEventVector[Types::kMaxTreeType].at(cls).end(),
                              rndm );
      }
   }

   // ------------------------------- splitting
   Log() << kDEBUG << "SPLITTING ========" << Endl;
   for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
      Log() << kDEBUG << "---- class " << cls << Endl;
      Log() << kDEBUG << "check number of training/testing events, requested and available number of events for class " << cls << Endl;

      // the event vectors for this class
      EventVector& eventVectorTraining  = tmpEventVector.find( Types::kTraining    )->second.at(cls);
      EventVector& eventVectorTesting   = tmpEventVector.find( Types::kTesting     )->second.at(cls);
      EventVector& eventVectorUndefined = tmpEventVector.find( Types::kMaxTreeType )->second.at(cls);

      Int_t alreadyAvailableTraining = eventVectorTraining.size();
      Int_t alreadyAvailableTesting  = eventVectorTesting.size();
      Int_t availableUndefined       = eventVectorUndefined.size();

      Int_t requestedTraining = nTrainTestEvents.find( Types::kTraining )->second.at(cls);
      Int_t requestedTesting  = nTrainTestEvents.find( Types::kTesting  )->second.at(cls);

      Log() << kDEBUG << "availableTraining  " << alreadyAvailableTraining << Endl;
      Log() << kDEBUG << "availableTesting   " << alreadyAvailableTesting << Endl;
      Log() << kDEBUG << "availableUndefined " << availableUndefined << Endl;
      Log() << kDEBUG << "requestedTraining  " << requestedTraining << Endl;
      Log() << kDEBUG << "requestedTesting   " << requestedTesting << Endl;

      // determine how many events to use for training and testing;
      // four cases are distinguished:
      //  - neither nTrain nor nTest requested: distribute all undefined
      //    events such that training and testing become as equal as possible
      //  - only nTrain requested: take nTrain for training (but at least
      //    the events already assigned to training), the rest for testing
      //  - only nTest requested: the mirrored case
      //  - both requested: take the requested numbers; any undefined events
      //    left over are split evenly between the two samples
      Int_t useForTesting = 0, useForTraining = 0;
      if( (requestedTraining == 0) && (requestedTesting == 0)){
         Log() << kDEBUG << "requested 0" << Endl;
         Int_t NFree = availableUndefined - TMath::Abs(alreadyAvailableTraining - alreadyAvailableTesting);
         if (NFree >= 0){
            requestedTraining = TMath::Max(alreadyAvailableTraining,alreadyAvailableTesting) + NFree/2;
            requestedTesting  = availableUndefined+alreadyAvailableTraining+alreadyAvailableTesting - requestedTraining;
         } else if (alreadyAvailableTraining > alreadyAvailableTesting){
            requestedTraining = alreadyAvailableTraining;
            requestedTesting  = alreadyAvailableTesting + availableUndefined;
         }
         else {
            requestedTraining = alreadyAvailableTraining + availableUndefined;
            requestedTesting  = alreadyAvailableTesting;
         }
         useForTraining = requestedTraining;
         useForTesting  = requestedTesting;
      }
      else if (requestedTesting == 0){
         useForTraining = TMath::Max(requestedTraining,alreadyAvailableTraining);
         useForTesting  = availableUndefined+alreadyAvailableTraining+alreadyAvailableTesting - useForTraining;
         requestedTesting = useForTesting;
      }
      else if (requestedTraining == 0){
         useForTesting  = TMath::Max(requestedTesting,alreadyAvailableTesting);
         useForTraining = availableUndefined+alreadyAvailableTraining+alreadyAvailableTesting - useForTesting;
         requestedTraining = useForTraining;
      }
      else {
         Int_t NFree = availableUndefined - TMath::Max(requestedTraining-alreadyAvailableTraining,0)
                                          - TMath::Max(requestedTesting -alreadyAvailableTesting, 0);
         if (NFree < 0) NFree = 0;
         useForTraining = TMath::Max(requestedTraining,alreadyAvailableTraining) + NFree/2;
         useForTesting  = availableUndefined+alreadyAvailableTraining+alreadyAvailableTesting - useForTraining;
      }
      Log() << kDEBUG << "determined event sample size to select training sample from=" << useForTraining << Endl;
      Log() << kDEBUG << "determined event sample size to select test sample from=" << useForTesting << Endl;

      // associate the undefined events to training and testing
      if( splitMode == "ALTERNATE" ){
         // alternating assignment (training, testing, training, ...); the
         // training insertions stop once the requested number is reached
         Log() << kDEBUG << "split 'ALTERNATE'" << Endl;
         Int_t nTraining = alreadyAvailableTraining;
         Int_t nTesting  = alreadyAvailableTesting;
         for( EventVector::iterator it = eventVectorUndefined.begin(), itEnd = eventVectorUndefined.end(); it != itEnd; ){
            ++nTraining;
            if( nTraining <= requestedTraining ){
               eventVectorTraining.insert( eventVectorTraining.end(), (*it) );
               ++it;
            }
            if( it != itEnd ){
               ++nTesting;
               eventVectorTesting.insert( eventVectorTesting.end(), (*it) );
               ++it;
            }
         }
      }else{
         // block or random split: take the events from the beginning of the
         // (possibly shuffled) undefined vector
         Log() << kDEBUG << "split '" << splitMode << "'" << Endl;

         Log() << kDEBUG << "availableundefined       : " << availableUndefined << Endl;
         Log() << kDEBUG << "useForTraining           : " << useForTraining << Endl;
         Log() << kDEBUG << "useForTesting            : " << useForTesting << Endl;
         Log() << kDEBUG << "alreadyAvailableTraining : " << alreadyAvailableTraining << Endl;
         Log() << kDEBUG << "alreadyAvailableTesting  : " << alreadyAvailableTesting << Endl;

         if( availableUndefined<(useForTraining-alreadyAvailableTraining) ||
             availableUndefined<(useForTesting -alreadyAvailableTesting ) ||
             availableUndefined<(useForTraining+useForTesting-alreadyAvailableTraining-alreadyAvailableTesting ) ){
            Log() << kFATAL << "More events requested than available!" << Endl;
         }

         if (useForTraining>alreadyAvailableTraining){
            eventVectorTraining.insert( eventVectorTraining.end(), eventVectorUndefined.begin(), eventVectorUndefined.begin() + useForTraining - alreadyAvailableTraining );
            eventVectorUndefined.erase( eventVectorUndefined.begin(), eventVectorUndefined.begin() + useForTraining - alreadyAvailableTraining );
         }
         if (useForTesting>alreadyAvailableTesting){
            eventVectorTesting.insert( eventVectorTesting.end(), eventVectorUndefined.begin(), eventVectorUndefined.begin() + useForTesting - alreadyAvailableTesting );
            // no erase needed: the remaining undefined events are cleared below
         }
      }
      eventVectorUndefined.clear();

      // for a random split, the vectors may now hold more events than
      // requested: remove randomly chosen surplus events
      if (splitMode.Contains( "RANDOM" )){
         UInt_t sizeTraining = eventVectorTraining.size();
         if( sizeTraining > UInt_t(requestedTraining) ){
            std::vector<UInt_t> indicesTraining( sizeTraining );
            // fill the vector with the indices 0,1,2,...
            std::generate( indicesTraining.begin(), indicesTraining.end(), TMVA::Increment<UInt_t>(0) );
            // shuffle the indices
            std::random_shuffle( indicesTraining.begin(), indicesTraining.end(), rndm );
            // keep only the indices of the events to be deleted
            indicesTraining.erase( indicesTraining.begin()+sizeTraining-UInt_t(requestedTraining), indicesTraining.end() );
            // delete the events at the chosen indices
            for( std::vector<UInt_t>::iterator it = indicesTraining.begin(), itEnd = indicesTraining.end(); it != itEnd; ++it ){
               delete eventVectorTraining.at( (*it) );
               eventVectorTraining.at( (*it) ) = NULL;
            }
            // remove the NULL slots
            eventVectorTraining.erase( std::remove( eventVectorTraining.begin(), eventVectorTraining.end(), (void*)NULL ), eventVectorTraining.end() );
         }

         UInt_t sizeTesting = eventVectorTesting.size();
         if( sizeTesting > UInt_t(requestedTesting) ){
            std::vector<UInt_t> indicesTesting( sizeTesting );
            std::generate( indicesTesting.begin(), indicesTesting.end(), TMVA::Increment<UInt_t>(0) );
            std::random_shuffle( indicesTesting.begin(), indicesTesting.end(), rndm );
            indicesTesting.erase( indicesTesting.begin()+sizeTesting-UInt_t(requestedTesting), indicesTesting.end() );
            for( std::vector<UInt_t>::iterator it = indicesTesting.begin(), itEnd = indicesTesting.end(); it != itEnd; ++it ){
               delete eventVectorTesting.at( (*it) );
               eventVectorTesting.at( (*it) ) = NULL;
            }
            eventVectorTesting.erase( std::remove( eventVectorTesting.begin(), eventVectorTesting.end(), (void*)NULL ), eventVectorTesting.end() );
         }
      }
      else {
         // non-random split: simply remove the surplus events from the end
         if( eventVectorTraining.size() < UInt_t(requestedTraining) )
            Log() << kWARNING << "DataSetFactory/requested number of training samples larger than size of eventVectorTraining.\n"
                  << "There is probably an issue. Please contact the TMVA developers." << Endl;
         std::for_each( eventVectorTraining.begin()+requestedTraining, eventVectorTraining.end(), DeleteFunctor<Event>() );
         eventVectorTraining.erase( eventVectorTraining.begin()+requestedTraining, eventVectorTraining.end() );

         if( eventVectorTesting.size() < UInt_t(requestedTesting) )
            Log() << kWARNING << "DataSetFactory/requested number of testing samples larger than size of eventVectorTesting.\n"
                  << "There is probably an issue. Please contact the TMVA developers." << Endl;
         std::for_each( eventVectorTesting.begin()+requestedTesting, eventVectorTesting.end(), DeleteFunctor<Event>() );
         eventVectorTesting.erase( eventVectorTesting.begin()+requestedTesting, eventVectorTesting.end() );
      }
   }

   // apply the weight renormalisation
   TMVA::DataSetFactory::RenormEvents( dsi, tmpEventVector, normMode );

   Int_t trainingSize = 0;
   Int_t testingSize  = 0;

   for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
      trainingSize += tmpEventVector[Types::kTraining].at(cls).size();
      testingSize  += tmpEventVector[Types::kTesting].at(cls).size();
   }

   // the final event collections (all classes merged)
   EventVector* trainingEventVector = new EventVector();
   EventVector* testingEventVector  = new EventVector();

   trainingEventVector->reserve( trainingSize );
   testingEventVector->reserve( testingSize );

   // ------------------------------- mixing
   Log() << kDEBUG << " MIXING ============= " << Endl;

   if( mixMode == "ALTERNATE" ){
      // the alternate mixing assumes the same number of events per class;
      // warn if this is not the case
      for( UInt_t cls = 1; cls < dsi.GetNClasses(); ++cls ){
         if (tmpEventVector[Types::kTraining].at(cls).size() != tmpEventVector[Types::kTraining].at(0).size()){
            Log() << kINFO << "Training sample: You are trying to mix events in alternate mode although the classes have different event numbers. This works but the alternation stops at the last event of the smaller class." << Endl;
         }
         if (tmpEventVector[Types::kTesting].at(cls).size() != tmpEventVector[Types::kTesting].at(0).size()){
            Log() << kINFO << "Testing sample: You are trying to mix events in alternate mode although the classes have different event numbers. This works but the alternation stops at the last event of the smaller class." << Endl;
         }
      }
      typedef EventVector::iterator EvtVecIt;
      EvtVecIt itEvent, itEventEnd;

      // insert the class 0 events as the starting sequence
      Log() << kDEBUG << "insert class 0 into training and test vector" << Endl;
      trainingEventVector->insert( trainingEventVector->end(), tmpEventVector[Types::kTraining].at(0).begin(), tmpEventVector[Types::kTraining].at(0).end() );
      testingEventVector->insert( testingEventVector->end(), tmpEventVector[Types::kTesting].at(0).begin(), tmpEventVector[Types::kTesting].at(0).end() );

      // interleave the events of the other classes
      EvtVecIt itTarget;
      for( UInt_t cls = 1; cls < dsi.GetNClasses(); ++cls ){
         Log() << kDEBUG << "insert class " << cls << Endl;
         // the insertion position is advanced by cls+1 before each insert
         itTarget = trainingEventVector->begin() - 1;
         for( itEvent = tmpEventVector[Types::kTraining].at(cls).begin(), itEventEnd = tmpEventVector[Types::kTraining].at(cls).end(); itEvent != itEventEnd; ++itEvent ){
            // if not enough positions are left, append the remaining events
            if( (trainingEventVector->end() - itTarget) < Int_t(cls+1) ) {
               itTarget = trainingEventVector->end();
               trainingEventVector->insert( itTarget, itEvent, itEventEnd );
               break;
            }else{
               itTarget += cls+1;
               trainingEventVector->insert( itTarget, (*itEvent) );
            }
         }

         itTarget = testingEventVector->begin() - 1;
         for( itEvent = tmpEventVector[Types::kTesting].at(cls).begin(), itEventEnd = tmpEventVector[Types::kTesting].at(cls).end(); itEvent != itEventEnd; ++itEvent ){
            if( ( testingEventVector->end() - itTarget ) < Int_t(cls+1) ) {
               itTarget = testingEventVector->end();
               testingEventVector->insert( itTarget, itEvent, itEventEnd );
               break;
            }else{
               itTarget += cls+1;
               testingEventVector->insert( itTarget, (*itEvent) );
            }
         }
      }
   }else{
      // block-wise mixing: simply append the classes one after the other
      for( UInt_t cls = 0; cls < dsi.GetNClasses(); ++cls ){
         trainingEventVector->insert( trainingEventVector->end(), tmpEventVector[Types::kTraining].at(cls).begin(), tmpEventVector[Types::kTraining].at(cls).end() );
         testingEventVector->insert ( testingEventVector->end(),  tmpEventVector[Types::kTesting].at(cls).begin(),  tmpEventVector[Types::kTesting].at(cls).end() );
      }
   }

   // the temporary vectors are no longer needed; the events themselves are
   // now owned by the merged training and testing vectors
   tmpEventVector[Types::kTraining].clear();
   tmpEventVector[Types::kTesting].clear();
   tmpEventVector[Types::kMaxTreeType].clear();

   if (mixMode == "RANDOM") {
      Log() << kDEBUG << "shuffling events" << Endl;
      std::random_shuffle( trainingEventVector->begin(), trainingEventVector->end(), rndm );
      std::random_shuffle( testingEventVector->begin(),  testingEventVector->end(),  rndm );
   }

   Log() << kDEBUG << "trainingEventVector " << trainingEventVector->size() << Endl;
   Log() << kDEBUG << "testingEventVector  " << testingEventVector->size() << Endl;

   // create the dataset from the merged collections
   DataSet* ds = new DataSet(dsi);

   Log() << kINFO << "Create internal training tree" << Endl;
   ds->SetEventCollection(trainingEventVector, Types::kTraining );
   Log() << kINFO << "Create internal testing tree" << Endl;
   ds->SetEventCollection(testingEventVector, Types::kTesting );

   return ds;
}

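//_______________________________________________________________________
// renormalises the event weights according to NormMode:
//  - "NONE"           : no renormalisation (use the original weights)
//  - "NUMEVENTS"      : for each class j, scale the weights such that
//                       Sum[i=1..N_j]{w_i} = N_j (average weight of 1 per
//                       event; training and test events counted together)
//  - "EQUALNUMEVENTS" : as "NUMEVENTS" for the first (reference) class;
//                       the other classes are additionally scaled so that
//                       their weight sums equal that of the reference class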
void TMVA::DataSetFactory::RenormEvents( TMVA::DataSetInfo& dsi,
                                         TMVA::EventVectorOfClassesOfTreeType& tmpEventVector,
                                         const TString& normMode )
{
   if (normMode == "NONE") {
      Log() << kINFO << "No weight renormalisation applied: use original event weights" << Endl;
      return;
   }

   // per-class event counts and weight sums (training and testing)
   Int_t trainingSize = 0;
   Int_t testingSize  = 0;

   ValuePerClass trainingSumWeightsPerClass( dsi.GetNClasses() );
   ValuePerClass testingSumWeightsPerClass( dsi.GetNClasses() );

   NumberPerClass trainingSizePerClass( dsi.GetNClasses() );
   NumberPerClass testingSizePerClass( dsi.GetNClasses() );

   Double_t trainingSumWeights = 0;
   Double_t testingSumWeights  = 0;

   for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
      trainingSizePerClass.at(cls) = tmpEventVector[Types::kTraining].at(cls).size();
      testingSizePerClass.at(cls)  = tmpEventVector[Types::kTesting].at(cls).size();

      trainingSize += trainingSizePerClass.at(cls);
      testingSize  += testingSizePerClass.at(cls);

      // sum up the original weights of this class
      trainingSumWeightsPerClass.at(cls) = std::accumulate( tmpEventVector[Types::kTraining].at(cls).begin(),
                                                            tmpEventVector[Types::kTraining].at(cls).end(),
                                                            Double_t(0),
                                                            compose_binary( std::plus<Double_t>(),
                                                                            null<Double_t>(),
                                                                            std::mem_fun(&TMVA::Event::GetOriginalWeight) ) );

      testingSumWeightsPerClass.at(cls) = std::accumulate( tmpEventVector[Types::kTesting].at(cls).begin(),
                                                           tmpEventVector[Types::kTesting].at(cls).end(),
                                                           Double_t(0),
                                                           compose_binary( std::plus<Double_t>(),
                                                                           null<Double_t>(),
                                                                           std::mem_fun(&TMVA::Event::GetOriginalWeight) ) );

      trainingSumWeights += trainingSumWeightsPerClass.at(cls);
      testingSumWeights  += testingSumWeightsPerClass.at(cls);
   }

   // compute the renormalisation factor for each class
   ValuePerClass renormFactor( dsi.GetNClasses() );

   if (normMode == "NUMEVENTS") {
      Log() << kINFO << "Weight renormalisation mode: \"NumEvents\": renormalise independently the ..." << Endl;
      Log() << kINFO << "... class weights so that Sum[i=1..N_j]{w_i} = N_j, j=0,1,2..." << Endl;
      Log() << kINFO << "... (note that N_j is the sum of training and test events)" << Endl;

      for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){
         renormFactor.at(cls) = ( (trainingSizePerClass.at(cls) + testingSizePerClass.at(cls))/
                                  (trainingSumWeightsPerClass.at(cls) + testingSumWeightsPerClass.at(cls)) );
      }
   }
   else if (normMode == "EQUALNUMEVENTS") {
      Log() << kINFO << "Weight renormalisation mode: \"EqualNumEvents\": renormalise class weights ..." << Endl;
      Log() << kINFO << "... so that Sum[i=1..N_j]{w_i} = N_classA, j=classA, classB, ..." << Endl;
      Log() << kINFO << "... (note that N_j is the sum of training and test events)" << Endl;

      // first normalise to the number of events per class, ...
      for (UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ) {
         renormFactor.at(cls) = Float_t(trainingSizePerClass.at(cls)+testingSizePerClass.at(cls))/
            (trainingSumWeightsPerClass.at(cls)+testingSumWeightsPerClass.at(cls));
      }

      // ... then rescale all other classes to the size of the reference class
      UInt_t referenceClass = 0;
      for (UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ) {
         if( cls == referenceClass ) continue;
         renormFactor.at(cls) *= Float_t(trainingSizePerClass.at(referenceClass)+testingSizePerClass.at(referenceClass) )/
            Float_t( trainingSizePerClass.at(cls)+testingSizePerClass.at(cls) );
      }
   }
   else {
      Log() << kFATAL << "<PrepareForTrainingAndTesting> Unknown NormMode: " << normMode << Endl;
   }

   // apply the renormalisation to all events
   Int_t maxL = dsi.GetClassNameMaxLength();
   for (UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls<clsEnd; ++cls) {
      Log() << kINFO << "--> Rescale " << setiosflags(ios::left) << std::setw(maxL)
            << dsi.GetClassInfo(cls)->GetName() << " event weights by factor: " << renormFactor.at(cls) << Endl;
      std::for_each( tmpEventVector[Types::kTraining].at(cls).begin(),
                     tmpEventVector[Types::kTraining].at(cls).end(),
                     std::bind2nd(std::mem_fun(&TMVA::Event::ScaleWeight),renormFactor.at(cls)) );
      std::for_each( tmpEventVector[Types::kTesting].at(cls).begin(),
                     tmpEventVector[Types::kTesting].at(cls).end(),
                     std::bind2nd(std::mem_fun(&TMVA::Event::ScaleWeight),renormFactor.at(cls)) );
   }

   dsi.SetNormalization( normMode );

   // print out the result
   Log() << kINFO << "Number of training and testing events after rescaling:" << Endl;
   Log() << kINFO << "------------------------------------------------------" << Endl;
   trainingSumWeights = 0;
   testingSumWeights  = 0;
   for( UInt_t cls = 0, clsEnd = dsi.GetNClasses(); cls < clsEnd; ++cls ){

      trainingSumWeightsPerClass.at(cls) = std::accumulate( tmpEventVector[Types::kTraining].at(cls).begin(),
                                                            tmpEventVector[Types::kTraining].at(cls).end(),
                                                            Double_t(0),
                                                            compose_binary( std::plus<Double_t>(),
                                                                            null<Double_t>(),
                                                                            std::mem_fun(&TMVA::Event::GetOriginalWeight) ) );

      testingSumWeightsPerClass.at(cls) = std::accumulate( tmpEventVector[Types::kTesting].at(cls).begin(),
                                                           tmpEventVector[Types::kTesting].at(cls).end(),
                                                           Double_t(0),
                                                           compose_binary( std::plus<Double_t>(),
                                                                           null<Double_t>(),
                                                                           std::mem_fun(&TMVA::Event::GetOriginalWeight) ) );

      trainingSumWeights += trainingSumWeightsPerClass.at(cls);
      testingSumWeights  += testingSumWeightsPerClass.at(cls);

      Log() << kINFO << setiosflags(ios::left) << std::setw(maxL)
            << dsi.GetClassInfo(cls)->GetName() << " -- "
            << "training entries            : " << trainingSizePerClass.at(cls)
            << " (" << "sum of weights: " << trainingSumWeightsPerClass.at(cls) << ")" << Endl;
      Log() << kINFO << setiosflags(ios::left) << std::setw(maxL)
            << dsi.GetClassInfo(cls)->GetName() << " -- "
            << "testing entries             : " << testingSizePerClass.at(cls)
            << " (" << "sum of weights: " << testingSumWeightsPerClass.at(cls) << ")" << Endl;
      Log() << kINFO << setiosflags(ios::left) << std::setw(maxL)
            << dsi.GetClassInfo(cls)->GetName() << " -- "
            << "training and testing entries: "
            << (trainingSizePerClass.at(cls)+testingSizePerClass.at(cls))
            << " (" << "sum of weights: "
            << (trainingSumWeightsPerClass.at(cls)+testingSumWeightsPerClass.at(cls)) << ")" << Endl;
   }
}