00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027 #include <vector>
00028
00029 #include "TEventList.h"
00030 #include "TFile.h"
00031 #include "TH1.h"
00032 #include "TH2.h"
00033 #include "TProfile.h"
00034 #include "TRandom3.h"
00035 #include "TMatrixF.h"
00036 #include "TVectorF.h"
00037 #include "TMath.h"
00038 #include "TROOT.h"
00039 #include "TObjString.h"
00040
00041 #ifndef ROOT_TMVA_MsgLogger
00042 #include "TMVA/MsgLogger.h"
00043 #endif
00044 #ifndef ROOT_TMVA_Tools
00045 #include "TMVA/Tools.h"
00046 #endif
00047 #ifndef ROOT_TMVA_DataSet
00048 #include "TMVA/DataSet.h"
00049 #endif
00050 #ifndef ROOT_TMVA_DataSetInfo
00051 #include "TMVA/DataSetInfo.h"
00052 #endif
00053 #ifndef ROOT_TMVA_DataSetManager
00054 #include "TMVA/DataSetManager.h"
00055 #endif
00056 #ifndef ROOT_TMVA_Event
00057 #include "TMVA/Event.h"
00058 #endif
00059
00060
00061 TMVA::DataSetInfo::DataSetInfo(const TString& name)
00062 : TObject(),
00063 fDataSetManager(NULL),
00064 fName(name),
00065 fDataSet( 0 ),
00066 fNeedsRebuilding( kTRUE ),
00067 fVariables(),
00068 fTargets(),
00069 fSpectators(),
00070 fClasses( 0 ),
00071 fNormalization( "NONE" ),
00072 fSplitOptions(""),
00073 fOwnRootDir(0),
00074 fVerbose( kFALSE ),
00075 fSignalClass(0),
00076 fTargetsForMulticlass(0),
00077 fLogger( new MsgLogger("DataSetInfo", kINFO) )
00078 {
00079
00080
00081 }
00082
00083
00084 TMVA::DataSetInfo::~DataSetInfo()
00085 {
00086
00087 ClearDataSet();
00088
00089 for(UInt_t i=0, iEnd = fClasses.size(); i<iEnd; ++i) {
00090 delete fClasses[i];
00091 }
00092
00093 delete fTargetsForMulticlass;
00094
00095 delete fLogger;
00096 }
00097
00098
00099 void TMVA::DataSetInfo::ClearDataSet() const
00100 {
00101 if(fDataSet!=0) { delete fDataSet; fDataSet=0; }
00102 }
00103
00104
00105 TMVA::ClassInfo* TMVA::DataSetInfo::AddClass( const TString& className )
00106 {
00107
00108 ClassInfo* theClass = GetClassInfo(className);
00109 if (theClass) return theClass;
00110
00111 fClasses.push_back( new ClassInfo(className) );
00112 fClasses.back()->SetNumber(fClasses.size()-1);
00113
00114 Log() << kINFO << "Added class \"" << className << "\"\t with internal class number "
00115 << fClasses.back()->GetNumber() << Endl;
00116
00117 if (className == "Signal") fSignalClass = fClasses.size()-1;
00118
00119 return fClasses.back();
00120 }
00121
00122
00123 void TMVA::DataSetInfo::SetMsgType( EMsgType t ) const
00124 {
00125 fLogger->SetMinType(t);
00126 }
00127
00128
00129 TMVA::ClassInfo* TMVA::DataSetInfo::GetClassInfo( const TString& name ) const
00130 {
00131 for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
00132 if ((*it)->GetName() == name) return (*it);
00133 }
00134 return 0;
00135 }
00136
00137
00138 TMVA::ClassInfo* TMVA::DataSetInfo::GetClassInfo( Int_t cls ) const
00139 {
00140 try {
00141 return fClasses.at(cls);
00142 }
00143 catch(...) {
00144 return 0;
00145 }
00146 }
00147
00148
00149 void TMVA::DataSetInfo::PrintClasses() const
00150 {
00151 for (UInt_t cls = 0; cls < GetNClasses() ; cls++) {
00152 Log() << kINFO << "Class index : " << cls << " name : " << GetClassInfo(cls)->GetName() << Endl;
00153 }
00154 }
00155
00156
00157 Bool_t TMVA::DataSetInfo::IsSignal( const TMVA::Event* ev ) const
00158 {
00159 return (ev->GetClass() == fSignalClass);
00160 }
00161
00162
00163 std::vector<Float_t>* TMVA::DataSetInfo::GetTargetsForMulticlass( const TMVA::Event* ev )
00164 {
00165 if( !fTargetsForMulticlass ) fTargetsForMulticlass = new std::vector<Float_t>( GetNClasses() );
00166
00167 fTargetsForMulticlass->assign( GetNClasses(), 0.0 );
00168 fTargetsForMulticlass->at( ev->GetClass() ) = 1.0;
00169 return fTargetsForMulticlass;
00170 }
00171
00172
00173
00174 Bool_t TMVA::DataSetInfo::HasCuts() const
00175 {
00176 Bool_t hasCuts = kFALSE;
00177 for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
00178 if( TString((*it)->GetCut()) != TString("") ) hasCuts = kTRUE;
00179 }
00180 return hasCuts;
00181 }
00182
00183
00184 const TMatrixD* TMVA::DataSetInfo::CorrelationMatrix( const TString& className ) const
00185 {
00186 ClassInfo* ptr = GetClassInfo(className);
00187 return ptr?ptr->GetCorrelationMatrix():0;
00188 }
00189
00190
00191 TMVA::VariableInfo& TMVA::DataSetInfo::AddVariable( const TString& expression, const TString& title, const TString& unit,
00192 Double_t min, Double_t max, char varType,
00193 Bool_t normalized, void* external )
00194 {
00195
00196
00197 TString regexpr = expression;
00198 regexpr.ReplaceAll(" ", "" );
00199 fVariables.push_back(VariableInfo( regexpr, title, unit,
00200 fVariables.size()+1, varType, external, min, max, normalized ));
00201 fNeedsRebuilding = kTRUE;
00202 return fVariables.back();
00203 }
00204
00205
00206 TMVA::VariableInfo& TMVA::DataSetInfo::AddVariable( const VariableInfo& varInfo){
00207
00208 fVariables.push_back(VariableInfo( varInfo ));
00209 fNeedsRebuilding = kTRUE;
00210 return fVariables.back();
00211 }
00212
00213
00214 TMVA::VariableInfo& TMVA::DataSetInfo::AddTarget( const TString& expression, const TString& title, const TString& unit,
00215 Double_t min, Double_t max,
00216 Bool_t normalized, void* external )
00217 {
00218
00219
00220 TString regexpr = expression;
00221 regexpr.ReplaceAll(" ", "" );
00222 char type='F';
00223 fTargets.push_back(VariableInfo( regexpr, title, unit,
00224 fTargets.size()+1, type, external, min, max, normalized ));
00225 fNeedsRebuilding = kTRUE;
00226 return fTargets.back();
00227 }
00228
00229
00230 TMVA::VariableInfo& TMVA::DataSetInfo::AddTarget( const VariableInfo& varInfo){
00231
00232 fTargets.push_back(VariableInfo( varInfo ));
00233 fNeedsRebuilding = kTRUE;
00234 return fTargets.back();
00235 }
00236
00237
00238 TMVA::VariableInfo& TMVA::DataSetInfo::AddSpectator( const TString& expression, const TString& title, const TString& unit,
00239 Double_t min, Double_t max, char type,
00240 Bool_t normalized, void* external )
00241 {
00242
00243
00244 TString regexpr = expression;
00245 regexpr.ReplaceAll(" ", "" );
00246 fSpectators.push_back(VariableInfo( regexpr, title, unit,
00247 fSpectators.size()+1, type, external, min, max, normalized ));
00248 fNeedsRebuilding = kTRUE;
00249 return fSpectators.back();
00250 }
00251
00252
00253 TMVA::VariableInfo& TMVA::DataSetInfo::AddSpectator( const VariableInfo& varInfo){
00254
00255 fSpectators.push_back(VariableInfo( varInfo ));
00256 fNeedsRebuilding = kTRUE;
00257 return fSpectators.back();
00258 }
00259
00260
00261 Int_t TMVA::DataSetInfo::FindVarIndex(const TString& var) const
00262 {
00263
00264 for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
00265 if (var == GetVariableInfo(ivar).GetInternalName()) return ivar;
00266
00267 for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
00268 Log() << kINFO << GetVariableInfo(ivar).GetInternalName() << Endl;
00269
00270 Log() << kFATAL << "<FindVarIndex> Variable \'" << var << "\' not found." << Endl;
00271
00272 return -1;
00273 }
00274
00275
00276 void TMVA::DataSetInfo::SetWeightExpression( const TString& expr, const TString& className )
00277 {
00278
00279
00280
00281
00282 if (className != "") {
00283 TMVA::ClassInfo* ci = AddClass(className);
00284 ci->SetWeight( expr );
00285 }
00286 else {
00287
00288 if (fClasses.size()==0) {
00289 Log() << kWARNING << "No classes registered yet, cannot specify weight expression!" << Endl;
00290 }
00291 for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
00292 (*it)->SetWeight( expr );
00293 }
00294 }
00295 }
00296
00297
00298 void TMVA::DataSetInfo::SetCorrelationMatrix( const TString& className, TMatrixD* matrix )
00299 {
00300 GetClassInfo(className)->SetCorrelationMatrix(matrix);
00301 }
00302
00303
00304 void TMVA::DataSetInfo::SetCut( const TCut& cut, const TString& className )
00305 {
00306
00307 if (className == "") {
00308 for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
00309 (*it)->SetCut( cut );
00310 }
00311 }
00312 else {
00313 TMVA::ClassInfo* ci = AddClass(className);
00314 ci->SetCut( cut );
00315 }
00316 }
00317
00318
00319 void TMVA::DataSetInfo::AddCut( const TCut& cut, const TString& className )
00320 {
00321
00322 if (className == "") {
00323 for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
00324 const TCut& oldCut = (*it)->GetCut();
00325 (*it)->SetCut( oldCut+cut );
00326 }
00327 }
00328 else {
00329 TMVA::ClassInfo* ci = AddClass(className);
00330 ci->SetCut( ci->GetCut()+cut );
00331 }
00332 }
00333
00334
00335 std::vector<TString> TMVA::DataSetInfo::GetListOfVariables() const
00336 {
00337
00338 std::vector<TString> vNames;
00339 std::vector<TMVA::VariableInfo>::const_iterator viIt = GetVariableInfos().begin();
00340 for(;viIt != GetVariableInfos().end(); viIt++) vNames.push_back( (*viIt).GetExpression() );
00341
00342 return vNames;
00343 }
00344
00345
00346 void TMVA::DataSetInfo::PrintCorrelationMatrix( const TString& className )
00347 {
00348
00349
00350 Log() << kINFO << "Correlation matrix (" << className << "):" << Endl;
00351 gTools().FormattedOutput( *CorrelationMatrix( className ), GetListOfVariables(), Log() );
00352 }
00353
00354
00355 TH2* TMVA::DataSetInfo::CreateCorrelationMatrixHist( const TMatrixD* m,
00356 const TString& hName,
00357 const TString& hTitle ) const
00358 {
00359 if (m==0) return 0;
00360
00361 const UInt_t nvar = GetNVariables();
00362
00363
00364
00365 TMatrixF* tm = new TMatrixF( nvar, nvar );
00366 for (UInt_t ivar=0; ivar<nvar; ivar++) {
00367 for (UInt_t jvar=0; jvar<nvar; jvar++) {
00368 (*tm)(ivar, jvar) = (*m)(ivar,jvar);
00369 }
00370 }
00371
00372 TH2F* h2 = new TH2F( *tm );
00373 h2->SetNameTitle( hName, hTitle );
00374
00375 for (UInt_t ivar=0; ivar<nvar; ivar++) {
00376 h2->GetXaxis()->SetBinLabel( ivar+1, GetVariableInfo(ivar).GetTitle() );
00377 h2->GetYaxis()->SetBinLabel( ivar+1, GetVariableInfo(ivar).GetTitle() );
00378 }
00379
00380
00381
00382 h2->Scale( 100.0 );
00383 for (UInt_t ibin=1; ibin<=nvar; ibin++) {
00384 for (UInt_t jbin=1; jbin<=nvar; jbin++) {
00385 h2->SetBinContent( ibin, jbin, Int_t(h2->GetBinContent( ibin, jbin )) );
00386 }
00387 }
00388
00389
00390 const Float_t labelSize = 0.055;
00391 h2->SetStats( 0 );
00392 h2->GetXaxis()->SetLabelSize( labelSize );
00393 h2->GetYaxis()->SetLabelSize( labelSize );
00394 h2->SetMarkerSize( 1.5 );
00395 h2->SetMarkerColor( 0 );
00396 h2->LabelsOption( "d" );
00397 h2->SetLabelOffset( 0.011 );
00398 h2->SetMinimum( -100.0 );
00399 h2->SetMaximum( +100.0 );
00400
00401
00402
00403
00404
00405
00406
00407
00408
00409
00410
00411 Log() << kDEBUG << "Created correlation matrix as 2D histogram: " << h2->GetName() << Endl;
00412
00413 return h2;
00414 }
00415
00416
00417 TMVA::DataSet* TMVA::DataSetInfo::GetDataSet() const
00418 {
00419
00420 if (fDataSet==0 || fNeedsRebuilding) {
00421 if(fDataSet!=0) ClearDataSet();
00422
00423 if( !fDataSetManager )
00424 Log() << kFATAL << "DataSetManager has not been set in DataSetInfo (GetDataSet() )." << Endl;
00425 fDataSet = fDataSetManager->CreateDataSet(GetName());
00426
00427
00428
00429 fNeedsRebuilding = kFALSE;
00430 }
00431 return fDataSet;
00432 }
00433
00434
00435 UInt_t TMVA::DataSetInfo::GetNSpectators(bool all) const
00436 {
00437 if(all)
00438 return fSpectators.size();
00439 UInt_t nsp(0);
00440 for(std::vector<VariableInfo>::const_iterator spit=fSpectators.begin(); spit!=fSpectators.end(); ++spit) {
00441 if(spit->GetVarType()!='C') nsp++;
00442 }
00443 return nsp;
00444 }
00445
00446
00447 Int_t TMVA::DataSetInfo::GetClassNameMaxLength() const
00448 {
00449 Int_t maxL = 0;
00450 for (UInt_t cl = 0; cl < GetNClasses(); cl++) {
00451 if (TString(GetClassInfo(cl)->GetName()).Length() > maxL) maxL = TString(GetClassInfo(cl)->GetName()).Length();
00452 }
00453
00454 return maxL;
00455 }
00456