DataSetInfo.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: DataSetInfo.cxx 37097 2010-11-30 12:28:05Z evt $
00002 // Author: Joerg Stelzer, Peter Speckmayer
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : DataSetInfo                                                           *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Implementation (see header for description)                               *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Peter Speckmayer <speckmay@mail.cern.ch> - CERN, Switzerland              *
00015  *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - DESY, Germany                  *
00016  *                                                                                *
00017  * Copyright (c) 2008:                                                            *
00018  *      CERN, Switzerland                                                         *
00019  *      MPI-K Heidelberg, Germany                                                 *
00020  *      DESY Hamburg, Germany                                                     *
00021  *                                                                                *
00022  * Redistribution and use in source and binary forms, with or without             *
00023  * modification, are permitted according to the terms listed in LICENSE           *
00024  * (http://tmva.sourceforge.net/LICENSE)                                          *
00025  **********************************************************************************/
00026 
00027 #include <vector>
00028 
00029 #include "TEventList.h"
00030 #include "TFile.h"
00031 #include "TH1.h"
00032 #include "TH2.h"
00033 #include "TProfile.h"
00034 #include "TRandom3.h"
00035 #include "TMatrixF.h"
00036 #include "TVectorF.h"
00037 #include "TMath.h"
00038 #include "TROOT.h"
00039 #include "TObjString.h"
00040 
00041 #ifndef ROOT_TMVA_MsgLogger
00042 #include "TMVA/MsgLogger.h"
00043 #endif
00044 #ifndef ROOT_TMVA_Tools
00045 #include "TMVA/Tools.h"
00046 #endif
00047 #ifndef ROOT_TMVA_DataSet
00048 #include "TMVA/DataSet.h"
00049 #endif
00050 #ifndef ROOT_TMVA_DataSetInfo
00051 #include "TMVA/DataSetInfo.h"
00052 #endif
00053 #ifndef ROOT_TMVA_DataSetManager
00054 #include "TMVA/DataSetManager.h"
00055 #endif
00056 #ifndef ROOT_TMVA_Event
00057 #include "TMVA/Event.h"
00058 #endif
00059 
00060 //_______________________________________________________________________
00061 TMVA::DataSetInfo::DataSetInfo(const TString& name) 
00062    : TObject(),
00063      fDataSetManager(NULL),
00064      fName(name),
00065      fDataSet( 0 ),
00066      fNeedsRebuilding( kTRUE ),
00067      fVariables(),
00068      fTargets(),
00069      fSpectators(),
00070      fClasses( 0 ),
00071      fNormalization( "NONE" ),
00072      fSplitOptions(""),
00073      fOwnRootDir(0),
00074      fVerbose( kFALSE ),
00075      fSignalClass(0),
00076      fTargetsForMulticlass(0),
00077      fLogger( new MsgLogger("DataSetInfo", kINFO) )
00078 {
00079    // constructor
00080 
00081 }
00082 
00083 //_______________________________________________________________________
00084 TMVA::DataSetInfo::~DataSetInfo() 
00085 {
00086    // destructor
00087    ClearDataSet();
00088    
00089    for(UInt_t i=0, iEnd = fClasses.size(); i<iEnd; ++i) {
00090       delete fClasses[i];
00091    }
00092 
00093    delete fTargetsForMulticlass;
00094 
00095    delete fLogger;
00096 }
00097 
00098 //_______________________________________________________________________
00099 void TMVA::DataSetInfo::ClearDataSet() const 
00100 {
00101    if(fDataSet!=0) { delete fDataSet; fDataSet=0; }
00102 }
00103 
00104 //_______________________________________________________________________
00105 TMVA::ClassInfo* TMVA::DataSetInfo::AddClass( const TString& className ) 
00106 {
00107 
00108    ClassInfo* theClass = GetClassInfo(className);
00109    if (theClass) return theClass;
00110 
00111    fClasses.push_back( new ClassInfo(className) );
00112    fClasses.back()->SetNumber(fClasses.size()-1);
00113 
00114    Log() << kINFO << "Added class \"" << className << "\"\t with internal class number " 
00115          << fClasses.back()->GetNumber() << Endl;
00116 
00117    if (className == "Signal") fSignalClass = fClasses.size()-1;  // store the signal class index ( for comparison reasons )
00118 
00119    return fClasses.back();
00120 }
00121 
00122 //_______________________________________________________________________
00123 void TMVA::DataSetInfo::SetMsgType( EMsgType t ) const 
00124 {  
00125     fLogger->SetMinType(t);  
00126 } 
00127 
00128 //_______________________________________________________________________
00129 TMVA::ClassInfo* TMVA::DataSetInfo::GetClassInfo( const TString& name ) const 
00130 {
00131    for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
00132       if ((*it)->GetName() == name) return (*it);
00133    }
00134    return 0;
00135 }
00136 
00137 //_______________________________________________________________________
00138 TMVA::ClassInfo* TMVA::DataSetInfo::GetClassInfo( Int_t cls ) const 
00139 {
00140    try {
00141       return fClasses.at(cls);
00142    }
00143    catch(...) {
00144       return 0;
00145    }
00146 }
00147 
00148 //_______________________________________________________________________
00149 void TMVA::DataSetInfo::PrintClasses() const 
00150 {
00151    for (UInt_t cls = 0; cls < GetNClasses() ; cls++) {
00152       Log() << kINFO << "Class index : " << cls << "  name : " << GetClassInfo(cls)->GetName() << Endl;
00153    }
00154 }
00155 
00156 //_______________________________________________________________________
00157 Bool_t TMVA::DataSetInfo::IsSignal( const TMVA::Event* ev ) const 
00158 {
00159    return (ev->GetClass()  == fSignalClass); 
00160 }
00161 
00162 //_______________________________________________________________________
00163 std::vector<Float_t>*  TMVA::DataSetInfo::GetTargetsForMulticlass( const TMVA::Event* ev ) 
00164 {
00165    if( !fTargetsForMulticlass ) fTargetsForMulticlass = new std::vector<Float_t>( GetNClasses() );
00166 //   fTargetsForMulticlass->resize( GetNClasses() );
00167    fTargetsForMulticlass->assign( GetNClasses(), 0.0 );
00168    fTargetsForMulticlass->at( ev->GetClass() ) = 1.0;
00169    return fTargetsForMulticlass; 
00170 }
00171 
00172 
00173 //_______________________________________________________________________
00174 Bool_t TMVA::DataSetInfo::HasCuts() const 
00175 {
00176    Bool_t hasCuts = kFALSE;
00177    for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
00178       if( TString((*it)->GetCut()) != TString("") ) hasCuts = kTRUE;
00179    }
00180    return hasCuts;
00181 }
00182 
00183 //_______________________________________________________________________
00184 const TMatrixD* TMVA::DataSetInfo::CorrelationMatrix( const TString& className ) const 
00185 { 
00186    ClassInfo* ptr = GetClassInfo(className);
00187    return ptr?ptr->GetCorrelationMatrix():0;
00188 }
00189 
00190 //_______________________________________________________________________
00191 TMVA::VariableInfo& TMVA::DataSetInfo::AddVariable( const TString& expression, const TString& title, const TString& unit, 
00192                                                     Double_t min, Double_t max, char varType,
00193                                                     Bool_t normalized, void* external )
00194 {
00195    // add a variable (can be a complex expression) to the set of variables used in
00196    // the MV analysis
00197    TString regexpr = expression; // remove possible blanks
00198    regexpr.ReplaceAll(" ", "" );
00199    fVariables.push_back(VariableInfo( regexpr, title, unit, 
00200                                       fVariables.size()+1, varType, external, min, max, normalized ));
00201    fNeedsRebuilding = kTRUE;
00202    return fVariables.back();
00203 }
00204 
00205 //_______________________________________________________________________
00206 TMVA::VariableInfo& TMVA::DataSetInfo::AddVariable( const VariableInfo& varInfo){
00207    // add variable with given VariableInfo
00208    fVariables.push_back(VariableInfo( varInfo ));
00209    fNeedsRebuilding = kTRUE;
00210    return fVariables.back();
00211 }
00212 
00213 //_______________________________________________________________________
00214 TMVA::VariableInfo& TMVA::DataSetInfo::AddTarget( const TString& expression, const TString& title, const TString& unit, 
00215                                                   Double_t min, Double_t max, 
00216                                                   Bool_t normalized, void* external )
00217 {
00218    // add a variable (can be a complex expression) to the set of variables used in
00219    // the MV analysis
00220    TString regexpr = expression; // remove possible blanks
00221    regexpr.ReplaceAll(" ", "" );
00222    char type='F';
00223    fTargets.push_back(VariableInfo( regexpr, title, unit, 
00224                                     fTargets.size()+1, type, external, min, max, normalized ));
00225    fNeedsRebuilding = kTRUE;
00226    return fTargets.back();
00227 }
00228 
00229 //_______________________________________________________________________
00230 TMVA::VariableInfo& TMVA::DataSetInfo::AddTarget( const VariableInfo& varInfo){
00231    // add target with given VariableInfo
00232    fTargets.push_back(VariableInfo( varInfo ));
00233    fNeedsRebuilding = kTRUE;
00234    return fTargets.back();
00235 }
00236 
00237 //_______________________________________________________________________
00238 TMVA::VariableInfo& TMVA::DataSetInfo::AddSpectator( const TString& expression, const TString& title, const TString& unit, 
00239                                                      Double_t min, Double_t max, char type,
00240                                                      Bool_t normalized, void* external )
00241 {
00242    // add a spectator (can be a complex expression) to the set of spectator variables used in
00243    // the MV analysis
00244    TString regexpr = expression; // remove possible blanks
00245    regexpr.ReplaceAll(" ", "" );
00246    fSpectators.push_back(VariableInfo( regexpr, title, unit, 
00247                                        fSpectators.size()+1, type, external, min, max, normalized ));
00248    fNeedsRebuilding = kTRUE;
00249    return fSpectators.back();
00250 }
00251 
00252 //_______________________________________________________________________
00253 TMVA::VariableInfo& TMVA::DataSetInfo::AddSpectator( const VariableInfo& varInfo){
00254    // add spectator with given VariableInfo
00255    fSpectators.push_back(VariableInfo( varInfo ));
00256    fNeedsRebuilding = kTRUE;
00257    return fSpectators.back();
00258 }
00259 
00260 //_______________________________________________________________________
00261 Int_t TMVA::DataSetInfo::FindVarIndex(const TString& var) const
00262 {
00263    // find variable by name
00264    for (UInt_t ivar=0; ivar<GetNVariables(); ivar++) 
00265       if (var == GetVariableInfo(ivar).GetInternalName()) return ivar;
00266    
00267    for (UInt_t ivar=0; ivar<GetNVariables(); ivar++) 
00268       Log() << kINFO  <<  GetVariableInfo(ivar).GetInternalName() << Endl;
00269    
00270    Log() << kFATAL << "<FindVarIndex> Variable \'" << var << "\' not found." << Endl;
00271  
00272    return -1;
00273 }
00274 
00275 //_______________________________________________________________________
00276 void TMVA::DataSetInfo::SetWeightExpression( const TString& expr, const TString& className ) 
00277 {
00278    // set the weight expressions for the classes
00279    // if class name is specified, set only for this class
00280    // if class name is unknown, register new class with this name
00281 
00282    if (className != "") {
00283       TMVA::ClassInfo* ci = AddClass(className);
00284       ci->SetWeight( expr );
00285    } 
00286    else {
00287       // no class name specified, set weight for all classes
00288       if (fClasses.size()==0) {
00289          Log() << kWARNING << "No classes registered yet, cannot specify weight expression!" << Endl;
00290       }
00291       for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
00292          (*it)->SetWeight( expr );
00293       }
00294    }
00295 }
00296 
00297 //_______________________________________________________________________
00298 void TMVA::DataSetInfo::SetCorrelationMatrix( const TString& className, TMatrixD* matrix ) 
00299 {
00300    GetClassInfo(className)->SetCorrelationMatrix(matrix); 
00301 }
00302 
00303 //_______________________________________________________________________
00304 void TMVA::DataSetInfo::SetCut( const TCut& cut, const TString& className ) 
00305 {
00306    // set the cut for the classes
00307    if (className == "") {  // if no className has been given set the cut for all the classes
00308       for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
00309          (*it)->SetCut( cut );
00310       }
00311    }
00312    else {
00313       TMVA::ClassInfo* ci = AddClass(className);
00314       ci->SetCut( cut );
00315    }
00316 }
00317 
00318 //_______________________________________________________________________
00319 void TMVA::DataSetInfo::AddCut( const TCut& cut, const TString& className ) 
00320 {
00321    // set the cut for the classes
00322    if (className == "") {  // if no className has been given set the cut for all the classes
00323       for (std::vector<ClassInfo*>::iterator it = fClasses.begin(); it < fClasses.end(); it++) {
00324          const TCut& oldCut = (*it)->GetCut(); 
00325          (*it)->SetCut( oldCut+cut );
00326       }
00327    }
00328    else {
00329       TMVA::ClassInfo* ci = AddClass(className);
00330       ci->SetCut( ci->GetCut()+cut );
00331    }
00332 }
00333 
00334 //_______________________________________________________________________
00335 std::vector<TString> TMVA::DataSetInfo::GetListOfVariables() const
00336 {
00337    // returns list of variables
00338    std::vector<TString> vNames;
00339    std::vector<TMVA::VariableInfo>::const_iterator viIt = GetVariableInfos().begin();
00340    for(;viIt != GetVariableInfos().end(); viIt++) vNames.push_back( (*viIt).GetExpression() );
00341 
00342    return vNames;
00343 }
00344 
00345 //_______________________________________________________________________
00346 void TMVA::DataSetInfo::PrintCorrelationMatrix( const TString& className )
00347 { 
00348    // calculates the correlation matrices for signal and background, 
00349    // prints them to standard output, and fills 2D histograms
00350    Log() << kINFO << "Correlation matrix (" << className << "):" << Endl;
00351    gTools().FormattedOutput( *CorrelationMatrix( className ), GetListOfVariables(), Log() );
00352 }
00353 
00354 //_______________________________________________________________________
00355 TH2* TMVA::DataSetInfo::CreateCorrelationMatrixHist( const TMatrixD* m,
00356                                                      const TString&  hName,
00357                                                      const TString&  hTitle ) const
00358 {
00359    if (m==0) return 0;
00360    
00361    const UInt_t nvar = GetNVariables();
00362 
00363    // workaround till the TMatrix templates are comonly used
00364    // this keeps backward compatibility
00365    TMatrixF* tm = new TMatrixF( nvar, nvar );
00366    for (UInt_t ivar=0; ivar<nvar; ivar++) {
00367       for (UInt_t jvar=0; jvar<nvar; jvar++) {
00368          (*tm)(ivar, jvar) = (*m)(ivar,jvar);
00369       }
00370    }  
00371 
00372    TH2F* h2 = new TH2F( *tm );
00373    h2->SetNameTitle( hName, hTitle );
00374 
00375    for (UInt_t ivar=0; ivar<nvar; ivar++) {
00376       h2->GetXaxis()->SetBinLabel( ivar+1, GetVariableInfo(ivar).GetTitle() );
00377       h2->GetYaxis()->SetBinLabel( ivar+1, GetVariableInfo(ivar).GetTitle() );
00378    }
00379    
00380    // present in percent, and round off digits
00381    // also, use absolute value of correlation coefficient (ignore sign)
00382    h2->Scale( 100.0  ); 
00383    for (UInt_t ibin=1; ibin<=nvar; ibin++) {
00384       for (UInt_t jbin=1; jbin<=nvar; jbin++) {
00385          h2->SetBinContent( ibin, jbin, Int_t(h2->GetBinContent( ibin, jbin )) );
00386       }
00387    }
00388    
00389    // style settings
00390    const Float_t labelSize = 0.055;
00391    h2->SetStats( 0 );
00392    h2->GetXaxis()->SetLabelSize( labelSize );
00393    h2->GetYaxis()->SetLabelSize( labelSize );
00394    h2->SetMarkerSize( 1.5 );
00395    h2->SetMarkerColor( 0 );
00396    h2->LabelsOption( "d" ); // diagonal labels on x axis
00397    h2->SetLabelOffset( 0.011 );// label offset on x axis
00398    h2->SetMinimum( -100.0 );
00399    h2->SetMaximum( +100.0 );
00400 
00401    // -------------------------------------------------------------------------------------
00402    // just in case one wants to change the position of the color palette axis
00403    // -------------------------------------------------------------------------------------
00404    //     gROOT->SetStyle("Plain");
00405    //     TStyle* gStyle = gROOT->GetStyle( "Plain" );
00406    //     gStyle->SetPalette( 1, 0 );
00407    //     TPaletteAxis* paletteAxis 
00408    //                   = (TPaletteAxis*)h2->GetListOfFunctions()->FindObject( "palette" );
00409    // -------------------------------------------------------------------------------------
00410    
00411    Log() << kDEBUG << "Created correlation matrix as 2D histogram: " << h2->GetName() << Endl;
00412    
00413    return h2;
00414 }
00415 
00416 //_______________________________________________________________________
00417 TMVA::DataSet* TMVA::DataSetInfo::GetDataSet() const 
00418 {
00419    // returns data set
00420    if (fDataSet==0 || fNeedsRebuilding) {
00421       if(fDataSet!=0) ClearDataSet();
00422 //      fDataSet = DataSetManager::Instance().CreateDataSet(GetName()); //DSMTEST replaced by following lines
00423       if( !fDataSetManager )
00424          Log() << kFATAL << "DataSetManager has not been set in DataSetInfo (GetDataSet() )." << Endl;
00425       fDataSet = fDataSetManager->CreateDataSet(GetName());
00426 
00427 
00428 
00429       fNeedsRebuilding = kFALSE;
00430    }
00431    return fDataSet;
00432 }
00433 
00434 //_______________________________________________________________________
00435 UInt_t TMVA::DataSetInfo::GetNSpectators(bool all) const
00436 {
00437    if(all)
00438       return fSpectators.size();
00439    UInt_t nsp(0);
00440    for(std::vector<VariableInfo>::const_iterator spit=fSpectators.begin(); spit!=fSpectators.end(); ++spit) {
00441       if(spit->GetVarType()!='C') nsp++;
00442    }
00443    return nsp;
00444 }
00445 
00446 //_______________________________________________________________________
00447 Int_t TMVA::DataSetInfo::GetClassNameMaxLength() const
00448 {
00449    Int_t maxL = 0;
00450    for (UInt_t cl = 0; cl < GetNClasses(); cl++) {
00451       if (TString(GetClassInfo(cl)->GetName()).Length() > maxL) maxL = TString(GetClassInfo(cl)->GetName()).Length();
00452    }
00453 
00454    return maxL;
00455 }
00456 

Generated on Tue Jul 5 15:16:47 2011 for ROOT_528-00b_version by  doxygen 1.5.1