TransformationHandler.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: TransformationHandler.cxx 36966 2010-11-26 09:50:13Z evt $
00002 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss, Eckhard von Toerne
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : TransformationHandler                                                 *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Implementation (see header for description)                               *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00015  *      Peter Speckmayer <speckmay@mail.cern.ch>  - CERN, Switzerland             *
00016  *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
00017  *      Eckhard v. Toerne  <evt@uni-bonn.de>     - U of Bonn, Germany          *  
00018  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00019  *                                                                                *
00020  * Copyright (c) 2008:                                                            *
00021  *      CERN, Switzerland                                                         *
00022  *      MPI-K Heidelberg, Germany                                                 *
00023  *      U. of Bonn, Germany                                                       *
00024  *                                                                                *
00025  * Redistribution and use in source and binary forms, with or without             *
00026  * modification, are permitted according to the terms listed in LICENSE           *
00027  * (http://tmva.sourceforge.net/LICENSE)                                          *
00028  **********************************************************************************/
00029 
00030 #include <vector>
00031 #include <iomanip>
00032 
00033 #include "TMath.h"
00034 #include "TH1.h"
00035 #include "TH2.h"
00036 #include "TAxis.h"
00037 #include "TProfile.h"
00038 
00039 #ifndef ROOT_TMVA_Config
00040 #include "TMVA/Config.h"
00041 #endif
00042 #ifndef ROOT_TMVA_DataSet
00043 #include "TMVA/DataSet.h"
00044 #endif
00045 #ifndef ROOT_TMVA_Event
00046 #include "TMVA/Event.h"
00047 #endif
00048 #ifndef ROOT_TMVA_MsgLogger
00049 #include "TMVA/MsgLogger.h"
00050 #endif
00051 #ifndef ROOT_TMVA_Ranking
00052 #include "TMVA/Ranking.h"
00053 #endif
00054 #ifndef ROOT_TMVA_Tools
00055 #include "TMVA/Tools.h"
00056 #endif
00057 #ifndef ROOT_TMVA_TransformationHandler
00058 #include "TMVA/TransformationHandler.h"
00059 #endif
00060 #ifndef ROOT_TMVA_VariableTransformBase
00061 #include "TMVA/VariableTransformBase.h"
00062 #endif
00063 #include "TMVA/VariableIdentityTransform.h"
00064 #include "TMVA/VariableDecorrTransform.h"
00065 #include "TMVA/VariablePCATransform.h"
00066 #include "TMVA/VariableGaussTransform.h"
00067 #include "TMVA/VariableNormalizeTransform.h"
00068 
00069 //_______________________________________________________________________
00070 TMVA::TransformationHandler::TransformationHandler( DataSetInfo& dsi, const TString& callerName ) 
00071    : fDataSetInfo(dsi),
00072      fRootBaseDir(0),
00073      fCallerName (callerName),
00074      fLogger     ( new MsgLogger(TString("TFHandler_" + callerName).Data(), kINFO) )
00075 {
00076    // constructor
00077 
00078    // produce one entry for each class and one entry for all classes. If there is only one class, 
00079    // produce only one entry
00080    fNumC = (dsi.GetNClasses()<= 1) ? 1 : dsi.GetNClasses()+1;
00081 
00082    fVariableStats.resize( fNumC );
00083    for (Int_t i=0; i<fNumC; i++ ) fVariableStats.at(i).resize(dsi.GetNVariables() + dsi.GetNTargets());
00084 }
00085 
00086 //_______________________________________________________________________
00087 TMVA::TransformationHandler::~TransformationHandler() 
00088 {
00089    // destructor
00090    std::vector<Ranking*>::const_iterator it = fRanking.begin();
00091    for (; it != fRanking.end(); it++) delete *it;
00092 
00093    fTransformations.SetOwner();
00094    delete fLogger;
00095 }
00096 
00097 //_______________________________________________________________________
00098 void TMVA::TransformationHandler::SetCallerName( const TString& name ) 
00099 { 
00100    fCallerName = name; 
00101    fLogger->SetSource( TString("TFHandler_" + fCallerName).Data() );
00102 }
00103 
00104 //_______________________________________________________________________
00105 TMVA::VariableTransformBase* TMVA::TransformationHandler::AddTransformation( VariableTransformBase *trf, Int_t cls ) 
00106 {
00107    TString tfname = trf->Log().GetName();
00108    trf->Log().SetSource(TString(fCallerName+"_"+tfname+"_TF").Data());
00109    fTransformations.Add(trf);
00110    fTransformationsReferenceClasses.push_back( cls );
00111    return trf;
00112 }
00113 
00114 //_______________________________________________________________________
00115 void TMVA::TransformationHandler::AddStats( Int_t k, UInt_t ivar, Double_t mean, Double_t rms, Double_t min, Double_t max ) 
00116 {
00117    if (rms <= 0) {
00118       Log() << kWARNING << "Variable \"" << Variable(ivar).GetExpression() 
00119             << "\" has zero or negative RMS^2 " 
00120             << "==> set to zero. Please check the variable content" << Endl;
00121       rms = 0;
00122    }
00123 
00124    VariableStat stat; stat.fMean = mean; stat.fRMS = rms; stat.fMin = min; stat.fMax = max;
00125    fVariableStats.at(k).at(ivar) = stat;
00126 }
00127 
00128 //_______________________________________________________________________
00129 void TMVA::TransformationHandler::SetTransformationReferenceClass( Int_t cls ) 
00130 {
00131    // overrides the setting for all classes! (this is put in basically for the likelihood-method)
00132    // be careful with the usage this method
00133    for (UInt_t i = 0; i < fTransformationsReferenceClasses.size(); i++) {
00134       fTransformationsReferenceClasses.at( i ) = cls;
00135    }
00136 }
00137 
00138 //_______________________________________________________________________
00139 const TMVA::Event* TMVA::TransformationHandler::Transform( const Event* ev ) const 
00140 {
00141    // the transformation
00142 
00143    TListIter trIt(&fTransformations);
00144    std::vector<Int_t>::const_iterator rClsIt = fTransformationsReferenceClasses.begin();
00145    const Event* trEv = ev;
00146    while (VariableTransformBase *trf = (VariableTransformBase*) trIt()) {
00147       trEv = trf->Transform(trEv, (*rClsIt) );
00148       rClsIt++;
00149    }
00150    return trEv;
00151 }
00152 
00153 //_______________________________________________________________________
00154 const TMVA::Event* TMVA::TransformationHandler::InverseTransform( const Event* ev ) const 
00155 {
00156    // the inverse transformation
00157    TListIter trIt(&fTransformations);
00158    std::vector< Int_t >::const_iterator rClsIt = fTransformationsReferenceClasses.begin();
00159    const Event* trEv = ev;
00160    while (VariableTransformBase *trf = (VariableTransformBase*) trIt() ) {
00161       if (trf->IsCreated()) trEv = trf->InverseTransform(ev, (*rClsIt) );
00162       else break;
00163       rClsIt++;
00164    }
00165    return trEv;
00166 }
00167 
00168 //_______________________________________________________________________
00169 std::vector<TMVA::Event*>* TMVA::TransformationHandler::CalcTransformations( const std::vector<Event*>& events, 
00170                                                                              Bool_t createNewVector ) 
00171 {
00172    // computation of transformation
00173    std::vector<Event*>* tmpEvents = const_cast<std::vector<Event*>*>(&events);
00174    Bool_t replaceColl = kFALSE; // first let TransformCollection create a new vector
00175 
00176    TListIter trIt(&fTransformations);
00177    std::vector< Int_t >::iterator rClsIt = fTransformationsReferenceClasses.begin();
00178    while (VariableTransformBase *trf = (VariableTransformBase*) trIt()) {
00179       if (trf->PrepareTransformation(*tmpEvents)) {
00180          tmpEvents = TransformCollection(trf, (*rClsIt), tmpEvents, replaceColl);
00181          // we now created a new vector, so the next transformations replace the 
00182          // events by their transformed versions
00183          replaceColl = kTRUE;  
00184          rClsIt++;
00185       }
00186    }
00187 
00188    CalcStats(*tmpEvents);
00189 
00190    // plot the variables once in this transformation
00191    PlotVariables(*tmpEvents);
00192 
00193    if (!createNewVector) {  // if we don't want that newly created event vector to persist, then delete it
00194       if (replaceColl) {    
00195          for ( UInt_t ievt = 0; ievt<tmpEvents->size(); ievt++)
00196             delete (*tmpEvents)[ievt];
00197          delete tmpEvents;
00198       }
00199       return 0;
00200    }
00201    return tmpEvents; // give back the newly created event collection (containing the transformed events)
00202 }
00203 
00204 //_______________________________________________________________________
00205 std::vector<TMVA::Event*>* TMVA::TransformationHandler::TransformCollection( VariableTransformBase* trf,
00206                                                                              Int_t cls,
00207                                                                              std::vector<TMVA::Event*>* events,
00208                                                                              Bool_t replace ) const 
00209 {
00210    // a collection of transformations
00211    std::vector<TMVA::Event*>* tmpEvents = 0;
00212 
00213    if (replace) {   // the events should be replaced by their transformed versions
00214       tmpEvents = events;
00215    } 
00216    else {           // a new event vector is created
00217       tmpEvents = new std::vector<TMVA::Event*>(events->size());
00218    }
00219    for (UInt_t ievt = 0; ievt<events->size(); ievt++) {  // loop through all events
00220       if (replace) {  // and replace the event by its transformed version
00221          *(*tmpEvents)[ievt] = *trf->Transform((*events)[ievt],cls);
00222       } 
00223       else {         // and create a new event which is the transformed version of the old event
00224          (*tmpEvents)[ievt] = new Event(*trf->Transform((*events)[ievt],cls));
00225       }
00226    }
00227    return tmpEvents;
00228 }
00229 
00230 //_______________________________________________________________________
00231 void TMVA::TransformationHandler::CalcStats( const std::vector<Event*>& events )
00232 {
00233 
00234    // method to calculate minimum, maximum, mean, and RMS for all
00235    // variables used in the MVA
00236 
00237    UInt_t nevts = events.size();
00238 
00239    if (nevts==0)
00240       Log() << kFATAL << "No events available to find min, max, mean and rms" << Endl;
00241 
00242    // if transformation has not been succeeded, the tree may be empty
00243    const UInt_t nvar = events[0]->GetNVariables();
00244    const UInt_t ntgt = events[0]->GetNTargets();
00245 
00246    Double_t  *sumOfWeights = new Double_t[fNumC];
00247    Double_t* *x2           = new Double_t*[fNumC];
00248    Double_t* *x0           = new Double_t*[fNumC];
00249    Double_t* *varMin       = new Double_t*[fNumC];
00250    Double_t* *varMax       = new Double_t*[fNumC];
00251    
00252    for (Int_t cls=0; cls<fNumC; cls++) {
00253       sumOfWeights[cls]=0;
00254       x2[cls]     = new Double_t[nvar+ntgt];
00255       x0[cls]     = new Double_t[nvar+ntgt];
00256       varMin[cls] = new Double_t[nvar+ntgt];
00257       varMax[cls] = new Double_t[nvar+ntgt];
00258       for (UInt_t ivar=0; ivar<nvar+ntgt; ivar++) {
00259          x0[cls][ivar] = x2[cls][ivar] = 0;
00260          varMin[cls][ivar] = DBL_MAX;
00261          varMax[cls][ivar] = -DBL_MAX;
00262       }
00263    }
00264 
00265    for (UInt_t ievt=0; ievt<nevts; ievt++) {
00266       Event* ev  = events[ievt];
00267       Int_t  cls = ev->GetClass();
00268 
00269       Double_t weight = ev->GetWeight();
00270       sumOfWeights[cls] += weight;
00271       if (fNumC > 1 ) sumOfWeights[fNumC-1] += weight; // if more than one class, store values for all classes
00272       for (UInt_t var_tgt = 0; var_tgt < 2; var_tgt++ ){ // first for variables, then for targets
00273          UInt_t nloop = ( var_tgt==0?nvar:ntgt );
00274          for (UInt_t ivar=0; ivar<nloop; ivar++) {
00275             Double_t x = ( var_tgt==0?ev->GetValue(ivar):ev->GetTarget(ivar) );
00276 
00277             if (x < varMin[cls][(var_tgt*nvar)+ivar]) varMin[cls][(var_tgt*nvar)+ivar]= x;
00278             if (x > varMax[cls][(var_tgt*nvar)+ivar]) varMax[cls][(var_tgt*nvar)+ivar]= x;
00279 
00280             x0[cls][(var_tgt*nvar)+ivar] += x*weight;
00281             x2[cls][(var_tgt*nvar)+ivar] += x*x*weight;
00282 
00283             if (fNumC > 1) {
00284                if (x < varMin[fNumC-1][(var_tgt*nvar)+ivar]) varMin[fNumC-1][(var_tgt*nvar)+ivar]= x;
00285                if (x > varMax[fNumC-1][(var_tgt*nvar)+ivar]) varMax[fNumC-1][(var_tgt*nvar)+ivar]= x;
00286 
00287                x0[fNumC-1][(var_tgt*nvar)+ivar] += x*weight;
00288                x2[fNumC-1][(var_tgt*nvar)+ivar] += x*x*weight;
00289             }
00290          }
00291       }
00292    }
00293 
00294 
00295    // set Mean and RMS
00296    for (UInt_t var_tgt = 0; var_tgt < 2; var_tgt++ ){ // first for variables, then for targets
00297       UInt_t nloop = ( var_tgt==0?nvar:ntgt );
00298       for (UInt_t ivar=0; ivar<nloop; ivar++) {
00299          for (Int_t cls = 0; cls < fNumC; cls++) {
00300             Double_t mean = x0[cls][(var_tgt*nvar)+ivar]/sumOfWeights[cls];
00301             Double_t rms = TMath::Sqrt( x2[cls][(var_tgt*nvar)+ivar]/sumOfWeights[cls] - mean*mean); 
00302             AddStats(cls, (var_tgt*nvar)+ivar, mean, rms, varMin[cls][(var_tgt*nvar)+ivar], varMax[cls][(var_tgt*nvar)+ivar]);
00303          }
00304       }
00305    }
00306 
00307    // ------ pretty output of basic statistics -------------------------------
00308    // find maximum length in V (and column title)
00309    UInt_t maxL = 8, maxV = 0;
00310    std::vector<UInt_t> vLengths;
00311    for (UInt_t ivar=0; ivar<nvar+ntgt; ivar++) {
00312       if( ivar < nvar )
00313          maxL = TMath::Max( (UInt_t)Variable(ivar).GetLabel().Length(), maxL );
00314       else
00315          maxL = TMath::Max( (UInt_t)Target(ivar-nvar).GetLabel().Length(), maxL );
00316    }
00317    maxV = maxL + 2;
00318    // full column length
00319    UInt_t clen = maxL + 4*maxV + 11;
00320    for (UInt_t i=0; i<clen; i++) Log() << "-";
00321    Log() << Endl;
00322    // full column length
00323    Log() << std::setw(maxL) << "Variable";
00324    Log() << "  " << std::setw(maxV) << "Mean";
00325    Log() << " " << std::setw(maxV) << "RMS";
00326    Log() << "   " << std::setw(maxV) << "[        Min ";
00327    Log() << "  " << std::setw(maxV) << "    Max ]" << Endl;;
00328    for (UInt_t i=0; i<clen; i++) Log() << "-";
00329    Log() << Endl;
00330 
00331    // the numbers
00332    TString format = "%#11.5g";
00333    for (UInt_t ivar=0; ivar<nvar+ntgt; ivar++) {
00334       if( ivar < nvar )
00335          Log() << std::setw(maxL) << Variable(ivar).GetLabel() << ":";
00336       else
00337          Log() << std::setw(maxL) << Target(ivar-nvar).GetLabel() << ":";
00338       Log() << std::setw(maxV) << Form( format.Data(), GetMean(ivar) );
00339       Log() << std::setw(maxV) << Form( format.Data(), GetRMS(ivar) );
00340       Log() << "   [" << std::setw(maxV) << Form( format.Data(), GetMin(ivar) );
00341       Log() << std::setw(maxV) << Form( format.Data(), GetMax(ivar) ) << " ]";
00342       Log() << Endl;
00343    }
00344    for (UInt_t i=0; i<clen; i++) Log() << "-";
00345    Log() << Endl;
00346    // ------------------------------------------------------------------------
00347    
00348    delete[] sumOfWeights;
00349    for (Int_t cls=0; cls<fNumC; cls++) {
00350       delete [] x2[cls];
00351       delete [] x0[cls];
00352       delete [] varMin[cls];
00353       delete [] varMax[cls];
00354    }
00355    delete [] x2;
00356    delete [] x0;
00357    delete [] varMin;
00358    delete [] varMax;
00359 }
00360 
00361 //_______________________________________________________________________
00362 void TMVA::TransformationHandler::MakeFunction( std::ostream& fout, const TString& fncName, Int_t part ) const 
00363 {
00364    // create transformation function
00365    TListIter trIt(&fTransformations);
00366    std::vector< Int_t >::const_iterator rClsIt = fTransformationsReferenceClasses.begin();
00367    UInt_t trCounter=1;
00368    while (VariableTransformBase *trf = (VariableTransformBase*) trIt() ) {
00369       trf->MakeFunction(fout, fncName, part, trCounter++, (*rClsIt) );
00370       rClsIt++;
00371    }
00372    if (part==1) {
00373       for (Int_t i=0; i<fTransformations.GetSize(); i++) {
00374          fout << "   void InitTransform_"<<i+1<<"();" << std::endl;
00375          fout << "   void Transform_"<<i+1<<"( std::vector<double> & iv, int sigOrBgd ) const;" << std::endl;
00376       }
00377    }
00378    if (part==2) {
00379       fout << std::endl;
00380       fout << "//_______________________________________________________________________" << std::endl;
00381       fout << "inline void " << fncName << "::InitTransform()" << std::endl;
00382       fout << "{" << std::endl;
00383       for (Int_t i=0; i<fTransformations.GetSize(); i++)
00384          fout << "   InitTransform_"<<i+1<<"();" << std::endl;
00385       fout << "}" << std::endl;
00386       fout << std::endl;
00387       fout << "//_______________________________________________________________________" << std::endl;
00388       fout << "inline void " << fncName << "::Transform( std::vector<double>& iv, int sigOrBgd ) const" << std::endl;
00389       fout << "{" << std::endl;
00390       for (Int_t i=0; i<fTransformations.GetSize(); i++)
00391          fout << "   Transform_"<<i+1<<"( iv, sigOrBgd );" << std::endl;
00392 
00393       fout << "}" << std::endl;
00394    }
00395 }
00396 
00397 //_______________________________________________________________________
00398 TString TMVA::TransformationHandler::GetName() const
00399 {
00400    // return transformation name
00401    TString name("Id");
00402    TListIter trIt(&fTransformations);
00403    VariableTransformBase *trf;
00404    if ((trf = (VariableTransformBase*) trIt())) {
00405       name = TString(trf->GetShortName());
00406       while ((trf = (VariableTransformBase*) trIt())) name += "_" + TString(trf->GetShortName());
00407    }
00408    return name;
00409 }
00410 
00411 //_______________________________________________________________________
00412 TString TMVA::TransformationHandler::GetVariableAxisTitle( const VariableInfo& info ) const
00413 {
00414    // incorporates transformation type into title axis (usually for histograms)
00415    TString xtit = info.GetTitle();
00416    // indicate transformation, but not in case of single identity transform
00417    if (fTransformations.GetSize() >= 1) {
00418       if (fTransformations.GetSize() > 1 ||
00419           ((VariableTransformBase*)GetTransformationList().Last())->GetVariableTransform() != Types::kIdentity) {
00420          xtit += " (" + GetName() + ")";
00421       }
00422    }
00423    return xtit;
00424 }
00425 
00426 //_______________________________________________________________________
00427 void TMVA::TransformationHandler::PlotVariables( const std::vector<Event*>& events, TDirectory* theDirectory )
00428 {
00429    // create histograms from the input variables
00430    // - histograms for all input variables
00431    // - scatter plots for all pairs of input variables
00432 
00433 
00434    if (fRootBaseDir==0 && theDirectory == 0) return;
00435 
00436    // extension for transformation type
00437    TString transfType = "";
00438    if (theDirectory == 0) {
00439       transfType += "_";
00440       transfType += GetName();
00441    }
00442 
00443    const UInt_t nvar = fDataSetInfo.GetNVariables();
00444    const UInt_t ntgt = fDataSetInfo.GetNTargets();
00445    const Int_t  ncls = fDataSetInfo.GetNClasses();
00446 
00447    // Create all histograms
00448    // do both, scatter and profile plots
00449    std::vector<std::vector<TH1*> > hVars( ncls );  // histograms for variables
00450    std::vector<std::vector<std::vector<TH2F*> > >     mycorr( ncls ); // histograms for correlations
00451    std::vector<std::vector<std::vector<TProfile*> > > myprof( ncls ); // histograms for profiles
00452 
00453    for (Int_t cls = 0; cls < ncls; cls++) {
00454       hVars.at(cls).resize ( nvar+ntgt );
00455       hVars.at(cls).assign ( nvar+ntgt, 0 ); // fill with zeros
00456       mycorr.at(cls).resize( nvar+ntgt );
00457       myprof.at(cls).resize( nvar+ntgt );
00458       for (UInt_t ivar=0; ivar < nvar+ntgt; ivar++) {
00459          mycorr.at(cls).at(ivar).resize( nvar+ntgt );
00460          myprof.at(cls).at(ivar).resize( nvar+ntgt );
00461          mycorr.at(cls).at(ivar).assign( nvar+ntgt, 0 ); // fill with zeros
00462          myprof.at(cls).at(ivar).assign( nvar+ntgt, 0 ); // fill with zeros
00463       }
00464    }
00465 
00466    // if there are too many input variables, the creation of correlations plots blows up
00467    // memory and basically kills the TMVA execution
00468    // --> avoid above critical number (which can be user defined)
00469    if (nvar+ntgt > (UInt_t)gConfig().GetVariablePlotting().fMaxNumOfAllowedVariablesForScatterPlots) {
00470       Int_t nhists = (nvar+ntgt)*(nvar+ntgt - 1)/2;
00471       Log() << kINFO << gTools().Color("dgreen") << Endl;
00472       Log() << kINFO << "<PlotVariables> Will not produce scatter plots ==> " << Endl;
00473       Log() << kINFO
00474             << "|  The number of " << nvar << " input variables and " << ntgt << " target values would require " 
00475             << nhists << " two-dimensional" << Endl;
00476       Log() << kINFO
00477             << "|  histograms, which would occupy the computer's memory. Note that this" << Endl;
00478       Log() << kINFO
00479             << "|  suppression does not have any consequences for your analysis, other" << Endl;
00480       Log() << kINFO
00481             << "|  than not disposing of these scatter plots. You can modify the maximum" << Endl;
00482       Log() << kINFO
00483             << "|  number of input variables allowed to generate scatter plots in your" << Endl; 
00484       Log() << "|  script via the command line:" << Endl;
00485       Log() << kINFO
00486             << "|  \"(TMVA::gConfig().GetVariablePlotting()).fMaxNumOfAllowedVariablesForScatterPlots = <some int>;\""
00487             << gTools().Color("reset") << Endl;
00488       Log() << Endl;
00489       Log() << kINFO << "Some more output" << Endl;
00490    }
00491 
00492    Double_t timesRMS = gConfig().GetVariablePlotting().fTimesRMS;
00493    UInt_t   nbins1D  = gConfig().GetVariablePlotting().fNbins1D;
00494    UInt_t   nbins2D  = gConfig().GetVariablePlotting().fNbins2D;
00495 
00496    for (UInt_t var_tgt = 0; var_tgt < 2; var_tgt++) { // create the histos first for the variables, then for the targets
00497       UInt_t nloops = ( var_tgt == 0? nvar:ntgt );     // number of variables or number of targets
00498       for (UInt_t ivar=0; ivar<nloops; ivar++) {
00499          const VariableInfo& info = ( var_tgt == 0 ? Variable( ivar ) : Target(ivar) ); // choose the appropriate one (variable or target)
00500          TString myVari = info.GetInternalName();  
00501 
00502          Double_t mean = fVariableStats.at(fNumC-1).at( ( var_tgt*nvar )+ivar).fMean;
00503          Double_t rms  = fVariableStats.at(fNumC-1).at( ( var_tgt*nvar )+ivar).fRMS;
00504 
00505          for (Int_t cls = 0; cls < ncls; cls++) {
00506 
00507             TString className = fDataSetInfo.GetClassInfo(cls)->GetName();
00508 
00509             // add "target" in case of target variable (required for plotting macros)
00510             className += (ntgt == 1 && var_tgt == 1 ? "_target" : ""); 
00511 
00512             // choose reasonable histogram ranges, by removing outliers
00513             TH1* h = 0;
00514             if (info.GetVarType() == 'I') {
00515                // special treatment for integer variables
00516                Int_t xmin = TMath::Nint( GetMin( ( var_tgt*nvar )+ivar) );
00517                Int_t xmax = TMath::Nint( GetMax( ( var_tgt*nvar )+ivar) + 1 );
00518                Int_t nbins = xmax - xmin;
00519 
00520                h = new TH1F( Form("%s__%s%s", myVari.Data(), className.Data(), transfType.Data()), 
00521                              info.GetTitle(), nbins, xmin, xmax );
00522             }
00523             else {
00524                Double_t xmin = TMath::Max( GetMin( ( var_tgt*nvar )+ivar), mean - timesRMS*rms );
00525                Double_t xmax = TMath::Min( GetMax( ( var_tgt*nvar )+ivar), mean + timesRMS*rms );
00526       
00527                // protection
00528                if (xmin >= xmax) xmax = xmin*1.1; // try first...
00529                if (xmin >= xmax) xmax = xmin + 1; // this if xmin == xmax == 0
00530                // safety margin for values equal to the maximum within the histogram
00531                xmax += (xmax - xmin)/nbins1D;
00532 
00533                h = new TH1F( Form("%s__%s%s", myVari.Data(), className.Data(), transfType.Data()), 
00534                              info.GetTitle(), nbins1D, xmin, xmax );
00535             }
00536             
00537             h->GetXaxis()->SetTitle( gTools().GetXTitleWithUnit( GetVariableAxisTitle( info ), info.GetUnit() ) );
00538             h->GetYaxis()->SetTitle( gTools().GetYTitleWithUnit( *h, info.GetUnit(), kFALSE ) );
00539             hVars.at(cls).at((var_tgt*nvar)+ivar) = h;
00540    
00541             // profile and scatter plots
00542             if (nvar+ntgt <= (UInt_t)gConfig().GetVariablePlotting().fMaxNumOfAllowedVariablesForScatterPlots) {
00543 
00544                for (UInt_t v_t = 0; v_t < 2; v_t++) {
00545                   UInt_t nl = ( v_t==0?nvar:ntgt );
00546                   UInt_t start = ( v_t==0? (var_tgt==0?ivar+1:0):(var_tgt==0?nl:ivar+1) );
00547                   for (UInt_t j=start; j<nl; j++) {
00548                      // choose the appropriate one (variable or target)
00549                      const VariableInfo& infoj = ( v_t == 0 ? Variable( j ) : Target(j) ); 
00550                      TString myVarj = infoj.GetInternalName();  
00551 
00552                      Double_t rxmin = fVariableStats.at(fNumC-1).at( ( v_t*nvar )+ivar).fMin;
00553                      Double_t rxmax = fVariableStats.at(fNumC-1).at( ( v_t*nvar )+ivar).fMax;
00554                      Double_t rymin = fVariableStats.at(fNumC-1).at( ( v_t*nvar )+j).fMin;
00555                      Double_t rymax = fVariableStats.at(fNumC-1).at( ( v_t*nvar )+j).fMax;
00556                      
00557                      // scatter plot
00558                      TH2F* h2 = new TH2F( Form( "scat_%s_vs_%s_%s%s" , myVarj.Data(), myVari.Data(), 
00559                                                 className.Data(), transfType.Data() ), 
00560                                           Form( "%s versus %s (%s)%s", infoj.GetTitle().Data(), info.GetTitle().Data(), 
00561                                                 className.Data(), transfType.Data() ), 
00562                                           nbins2D, rxmin , rxmax, 
00563                                           nbins2D, rymin , rymax );
00564 
00565                      h2->GetXaxis()->SetTitle( gTools().GetXTitleWithUnit( GetVariableAxisTitle( info  ), info .GetUnit() ) );
00566                      h2->GetYaxis()->SetTitle( gTools().GetXTitleWithUnit( GetVariableAxisTitle( infoj ), infoj.GetUnit() ) );
00567                      mycorr.at(cls).at((var_tgt*nvar)+ivar).at((v_t*nvar)+j) = h2;
00568                      
00569                      // profile plot
00570                      TProfile* p = new TProfile( Form( "prof_%s_vs_%s_%s%s", myVarj.Data(), 
00571                                                        myVari.Data(), className.Data(), 
00572                                                        transfType.Data() ), 
00573                                                  Form( "profile %s versus %s (%s)%s", 
00574                                                        infoj.GetTitle().Data(), info.GetTitle().Data(), 
00575                                                        className.Data(), transfType.Data() ), nbins1D, 
00576                                                  rxmin, rxmax );
00577                      //                                                 info.GetMin(), info.GetMax() );
00578 
00579                      p->GetXaxis()->SetTitle( gTools().GetXTitleWithUnit( GetVariableAxisTitle( info  ), info .GetUnit() ) );
00580                      p->GetYaxis()->SetTitle( gTools().GetXTitleWithUnit( GetVariableAxisTitle( infoj ), infoj.GetUnit() ) );
00581                      myprof.at(cls).at((var_tgt*nvar)+ivar).at((v_t*nvar)+j) = p;
00582                   }
00583                }
00584             }   
00585          }
00586       }
00587    }
00588 
00589    UInt_t nevts = events.size();
00590 
00591    // compute correlation coefficient between target value and variables (regression only)
00592    std::vector<Double_t> xregmean ( nvar+1, 0 );
00593    std::vector<Double_t> x2regmean( nvar+1, 0 );
00594    std::vector<Double_t> xCregmean( nvar+1, 0 );
00595 
00596    // fill the histograms (this approach should be faster than individual projection
00597    for (UInt_t ievt=0; ievt<nevts; ievt++) {
00598 
00599       const Event* ev = events[ievt];
00600 
00601       Float_t weight = ev->GetWeight();
00602       Int_t   cls    = ev->GetClass();
00603 
00604       // average correlation between first target and variables (so far only for single-target regression)
00605       if (ntgt == 1) {
00606          Float_t valr = ev->GetTarget(0);
00607          xregmean[nvar]  += valr;
00608          x2regmean[nvar] += valr*valr;
00609          for (UInt_t ivar=0; ivar<nvar; ivar++) {
00610             Float_t vali = ev->GetValue(ivar);
00611             xregmean[ivar]  += vali;
00612             x2regmean[ivar] += vali*vali;
00613             xCregmean[ivar] += vali*valr;
00614          }
00615       }
00616       
00617       // fill correlation histograms
00618       for (UInt_t var_tgt = 0; var_tgt < 2; var_tgt++) { // create the histos first for the variables, then for the targets
00619          UInt_t nloops = ( var_tgt == 0? nvar:ntgt );    // number of variables or number of targets
00620          for (UInt_t ivar=0; ivar<nloops; ivar++) {
00621             Float_t vali = ( var_tgt == 0 ? ev->GetValue(ivar) : ev->GetTarget(ivar) );
00622 
00623             // variable histos
00624             hVars.at(cls).at( ( var_tgt*nvar )+ivar)->Fill( vali, weight );
00625 
00626             // correlation histos
00627             if (nvar+ntgt <= (UInt_t)gConfig().GetVariablePlotting().fMaxNumOfAllowedVariablesForScatterPlots) {
00628 
00629                for (UInt_t v_t = 0; v_t < 2; v_t++) {
00630                   UInt_t nl    = ( v_t==0 ? nvar : ntgt );
00631                   UInt_t start = ( v_t==0 ? (var_tgt==0?ivar+1:0) : (var_tgt==0?nl:ivar+1) );
00632                   for (UInt_t j=start; j<nl; j++) {
00633                      Float_t valj = ( v_t == 0 ? ev->GetValue(j) : ev->GetTarget(j) );
00634                      mycorr.at(cls).at( ( var_tgt*nvar )+ivar).at( ( v_t*nvar )+j)->Fill( vali, valj, weight );
00635                      myprof.at(cls).at( ( var_tgt*nvar )+ivar).at( ( v_t*nvar )+j)->Fill( vali, valj, weight );
00636                   }
00637                }
00638             }
00639          }
00640       }
00641    }
00642       
00643    // correlation analysis for ranking  (single-target regression only)
00644    if (ntgt == 1) {
00645       for (UInt_t ivar=0; ivar<=nvar; ivar++) {
00646          xregmean[ivar] /= nevts;
00647          x2regmean[ivar] = x2regmean[ivar]/nevts - xregmean[ivar]*xregmean[ivar];
00648       }
00649       for (UInt_t ivar=0; ivar<nvar; ivar++) {
00650          xCregmean[ivar] = xCregmean[ivar]/nevts - xregmean[ivar]*xregmean[nvar];
00651          xCregmean[ivar] /= TMath::Sqrt( x2regmean[ivar]*x2regmean[nvar] );
00652       }         
00653       
00654       fRanking.push_back( new Ranking( GetName() + "Transformation", "|Correlation with target|" ) );
00655       for (UInt_t ivar=0; ivar<nvar; ivar++) {   
00656          Double_t abscor = TMath::Abs( xCregmean[ivar] );
00657          fRanking.back()->AddRank( Rank( fDataSetInfo.GetVariableInfo(ivar).GetLabel(), abscor ) );
00658       }
00659 
00660       if (nvar+ntgt <= (UInt_t)gConfig().GetVariablePlotting().fMaxNumOfAllowedVariablesForScatterPlots) {
00661       
00662          // compute also mutual information (non-linear correlation measure)
00663          fRanking.push_back( new Ranking( GetName() + "Transformation", "Mutual information" ) );
00664          for (UInt_t ivar=0; ivar<nvar; ivar++) {   
00665             TH2F* h1 = mycorr.at(0).at( nvar ).at( ivar );
00666             Double_t mi = gTools().GetMutualInformation( *h1 );
00667             fRanking.back()->AddRank( Rank( fDataSetInfo.GetVariableInfo(ivar).GetLabel(), mi ) );
00668          }     
00669          
00670          // compute correlation ratio (functional correlations measure)
00671          fRanking.push_back( new Ranking( GetName() + "Transformation", "Correlation Ratio" ) );
00672          for (UInt_t ivar=0; ivar<nvar; ivar++) {   
00673             TH2F*    h2 = mycorr.at(0).at( nvar ).at( ivar );
00674             Double_t cr = gTools().GetCorrelationRatio( *h2 );
00675             fRanking.back()->AddRank( Rank( fDataSetInfo.GetVariableInfo(ivar).GetLabel(), cr ) );
00676          } 
00677          
00678          // additionally compute correlation ratio from transposed histograms since correlation ratio is asymmetric
00679          fRanking.push_back( new Ranking( GetName() + "Transformation", "Correlation Ratio (T)" ) );
00680          for (UInt_t ivar=0; ivar<nvar; ivar++) {   
00681             TH2F*    h2T = gTools().TransposeHist( *mycorr.at(0).at( nvar ).at( ivar ) );
00682             Double_t cr  = gTools().GetCorrelationRatio( *h2T  );
00683             fRanking.back()->AddRank( Rank( fDataSetInfo.GetVariableInfo(ivar).GetLabel(), cr ) );
00684             delete h2T;
00685          }      
00686       }
00687    }
00688    // computes ranking of input variables
00689    // separation for 2-class classification
00690    else if (fDataSetInfo.GetNClasses() == 2 
00691             && fDataSetInfo.GetClassInfo("Signal") != NULL 
00692             && fDataSetInfo.GetClassInfo("Background") != NULL 
00693       ) { // TODO: ugly hack.. adapt to new framework
00694       fRanking.push_back( new Ranking( GetName() + "Transformation", "Separation" ) );
00695       for (UInt_t i=0; i<nvar; i++) {   
00696          Double_t sep = gTools().GetSeparation( hVars.at(fDataSetInfo.GetClassInfo("Signal")    ->GetNumber()).at(i), 
00697                                                 hVars.at(fDataSetInfo.GetClassInfo("Background")->GetNumber()).at(i) );
00698          fRanking.back()->AddRank( Rank( hVars.at(fDataSetInfo.GetClassInfo("Signal")->GetNumber()).at(i)->GetTitle(), 
00699                                          sep ) );
00700       }
00701    }
00702 
00703    // for regression compute performance from correlation with target value
00704 
00705    // write histograms
00706 
00707    TDirectory* localDir = theDirectory;
00708    if (theDirectory == 0) {
00709       // create directory in root dir
00710       fRootBaseDir->cd();
00711       TString outputDir = TString("InputVariables");
00712       TListIter trIt(&fTransformations);
00713       while (VariableTransformBase *trf = (VariableTransformBase*) trIt())
00714          outputDir += "_" + TString(trf->GetShortName());
00715 
00716       TObject* o = fRootBaseDir->FindObject(outputDir);
00717       if (o != 0) {
00718          Log() << kFATAL << "A " << o->ClassName() << " with name " << o->GetName() << " already exists in " 
00719                << fRootBaseDir->GetPath() << "("<<outputDir<<")" << Endl;
00720       }
00721       localDir = fRootBaseDir->mkdir( outputDir );
00722       localDir->cd();
00723    
00724       Log() << kVERBOSE << "Create and switch to directory " << localDir->GetPath() << Endl;
00725    }
00726    else {
00727       theDirectory->cd();
00728    }
00729 
00730    for (UInt_t i=0; i<nvar+ntgt; i++) {
00731       for (Int_t cls = 0; cls < ncls; cls++) {
00732          if (hVars.at(cls).at(i) != 0) {
00733             hVars.at(cls).at(i)->Write();
00734             hVars.at(cls).at(i)->SetDirectory(0);
00735             delete hVars.at(cls).at(i);
00736          }
00737       }
00738    }
00739 
00740    // correlation plots have dedicated directory
00741    if (nvar+ntgt <= (UInt_t)gConfig().GetVariablePlotting().fMaxNumOfAllowedVariablesForScatterPlots) {
00742 
00743       localDir = localDir->mkdir( "CorrelationPlots" );
00744       localDir ->cd();
00745       Log() << kINFO << "Create scatter and profile plots in target-file directory: " << Endl;
00746       Log() << kINFO << localDir->GetPath() << Endl;
00747    
00748       
00749       for (UInt_t i=0; i<nvar+ntgt; i++) {
00750          for (UInt_t j=i+1; j<nvar+ntgt; j++) {
00751             for (Int_t cls = 0; cls < ncls; cls++) {
00752                if (mycorr.at(cls).at(i).at(j) != 0 ) {
00753                   mycorr.at(cls).at(i).at(j)->Write();
00754                   mycorr.at(cls).at(i).at(j)->SetDirectory(0);
00755                   delete mycorr.at(cls).at(i).at(j);
00756                }
00757                if (myprof.at(cls).at(i).at(j) != 0) {
00758                   myprof.at(cls).at(i).at(j)->Write();
00759                   myprof.at(cls).at(i).at(j)->SetDirectory(0);
00760                   delete myprof.at(cls).at(i).at(j);
00761                }
00762             }
00763          }
00764       }
00765    }
00766    if (theDirectory != 0 ) theDirectory->cd();
00767    else                    fRootBaseDir->cd();
00768 }
00769 
00770 //_______________________________________________________________________
00771 std::vector<TString>* TMVA::TransformationHandler::GetTransformationStringsOfLastTransform() const
00772 {
00773    // returns string for transformation
00774    VariableTransformBase* trf = ((VariableTransformBase*)GetTransformationList().Last());
00775    if (!trf) return 0;
00776    else      return trf->GetTransformationStrings( fTransformationsReferenceClasses.back() );
00777 }
00778 
00779 //_______________________________________________________________________
00780 const char* TMVA::TransformationHandler::GetNameOfLastTransform() const
00781 {
00782    // returns string for transformation
00783    VariableTransformBase* trf = ((VariableTransformBase*)GetTransformationList().Last());
00784    if (!trf) return 0;
00785    else      return trf->GetName();
00786 }
00787 
00788 //_______________________________________________________________________
00789 void TMVA::TransformationHandler::WriteToStream( std::ostream& o ) const 
00790 {
00791    // write transformatino to stream
00792    TListIter trIt(&fTransformations);
00793    std::vector< Int_t >::const_iterator rClsIt = fTransformationsReferenceClasses.begin();
00794 
00795    o << "NTransformtations " << fTransformations.GetSize() << std::endl << std::endl;
00796 
00797    ClassInfo* ci;
00798    UInt_t i = 1;
00799    while (VariableTransformBase *trf = (VariableTransformBase*) trIt()) {
00800       o << "#TR -*-*-*-*-*-*-* transformation " << i++ << ": " << trf->GetName() << " -*-*-*-*-*-*-*-" << std::endl;
00801       trf->WriteTransformationToStream(o);
00802       ci = fDataSetInfo.GetClassInfo( (*rClsIt) );
00803       TString clsName;
00804       if (ci == 0 ) clsName = "AllClasses";
00805       else clsName = ci->GetName();
00806       o << "ReferenceClass " << clsName << std::endl; 
00807       rClsIt++;
00808    }
00809 }
00810 
00811 
00812 //_______________________________________________________________________
00813 void TMVA::TransformationHandler::AddXMLTo( void* parent ) const 
00814 {
00815    // XML node describing the transformation
00816    //   return;
00817    if(!parent) return;
00818    void* trfs = gTools().AddChild(parent, "Transformations");
00819    gTools().AddAttr( trfs, "NTransformations", fTransformations.GetSize() );
00820    TListIter trIt(&fTransformations);
00821    while (VariableTransformBase *trf = (VariableTransformBase*) trIt()) trf->AttachXMLTo(trfs);
00822 }
00823 
00824 //_______________________________________________________________________
00825 void TMVA::TransformationHandler::ReadFromStream( std::istream& ) 
00826 {
00827    //VariableTransformBase* trf = ((VariableTransformBase*)GetTransformationList().Last());
00828    //trf->ReadTransformationFromStream(fin);
00829    Log() << kFATAL << "Read transformations not implemented" << Endl;
00830    // TODO
00831 }
00832 
00833 //_______________________________________________________________________
00834 void TMVA::TransformationHandler::ReadFromXML( void* trfsnode )
00835 {
00836    void* ch = gTools().GetChild( trfsnode );
00837    while(ch) {
00838       Int_t idxCls = -1;
00839       TString trfname;
00840       gTools().ReadAttr(ch, "Name", trfname);
00841 
00842       VariableTransformBase* newtrf = 0;
00843 
00844       if (trfname == "Decorrelation" ) {
00845          newtrf = new VariableDecorrTransform(fDataSetInfo);
00846       }
00847       else if (trfname == "PCA" ) {
00848          newtrf = new VariablePCATransform(fDataSetInfo);
00849       }
00850       else if (trfname == "Gauss" ) {
00851          newtrf = new VariableGaussTransform(fDataSetInfo);
00852       }
00853       else if (trfname == "Normalize" ) {
00854          newtrf = new VariableNormalizeTransform(fDataSetInfo);
00855       }
00856       else {
00857          Log() << kFATAL << "<ReadFromXML> Variable transform '"
00858                << trfname << "' unknown." << Endl;
00859       }
00860       newtrf->ReadFromXML( ch );
00861       AddTransformation( newtrf, idxCls );
00862       ch = gTools().GetNextChild(ch);
00863    }
00864 }
00865 
00866 //_______________________________________________________________________
00867 void TMVA::TransformationHandler::PrintVariableRanking() const
00868 {
00869    // prints ranking of input variables
00870    Log() << kINFO << " " << Endl;
00871    Log() << kINFO << "Ranking input variables (method unspecific)..." << Endl;
00872    std::vector<Ranking*>::const_iterator it = fRanking.begin();
00873    for (; it != fRanking.end(); it++) (*it)->Print();
00874 }
00875 
00876 //_______________________________________________________________________
00877 Double_t TMVA::TransformationHandler::GetMean( Int_t ivar, Int_t cls ) const
00878 {
00879    try {
00880       return fVariableStats.at(cls).at(ivar).fMean;
00881    }
00882    catch(...) {
00883       try {
00884          return fVariableStats.at(fNumC-1).at(ivar).fMean;
00885       }
00886       catch(...) {
00887          Log() << kWARNING << "Inconsistent variable state when reading the mean value. " << Endl;
00888       }
00889    }
00890    Log() << kWARNING << "Inconsistent variable state when reading the mean value. Value 0 given back" << Endl;
00891    return 0;
00892 }
00893 
00894 
00895 //_______________________________________________________________________
00896 Double_t TMVA::TransformationHandler::GetRMS( Int_t ivar, Int_t cls ) const
00897 {
00898    try {
00899       return fVariableStats.at(cls).at(ivar).fRMS;
00900    }
00901    catch(...) {
00902       try {
00903          return fVariableStats.at(fNumC-1).at(ivar).fRMS;
00904       }
00905       catch(...) {
00906          Log() << kWARNING << "Inconsistent variable state when reading the RMS value. " << Endl;
00907       }
00908    }
00909    Log() << kWARNING << "Inconsistent variable state when reading the RMS value. Value 0 given back" << Endl;
00910    return 0;
00911 }
00912 
00913 //_______________________________________________________________________
00914 Double_t TMVA::TransformationHandler::GetMin( Int_t ivar, Int_t cls ) const
00915 {
00916    try {
00917       return fVariableStats.at(cls).at(ivar).fMin;
00918    }
00919    catch(...) {
00920       try {
00921          return fVariableStats.at(fNumC-1).at(ivar).fMin;
00922       }
00923       catch(...) {
00924          Log() << kWARNING << "Inconsistent variable state when reading the minimum value. " << Endl;
00925       }
00926    }
00927    Log() << kWARNING << "Inconsistent variable state when reading the minimum value. Value 0 given back" << Endl;
00928    return 0;
00929 }
00930 
00931 //_______________________________________________________________________
00932 Double_t TMVA::TransformationHandler::GetMax( Int_t ivar, Int_t cls ) const
00933 {
00934    try {
00935       return fVariableStats.at(cls).at(ivar).fMax;
00936    }
00937    catch(...) {
00938       try {
00939          return fVariableStats.at(fNumC-1).at(ivar).fMax;
00940       }
00941       catch(...) {
00942          Log() << kWARNING << "Inconsistent variable state when reading the maximum value. " << Endl;
00943       }
00944    }
00945    Log() << kWARNING << "Inconsistent variable state when reading the maximum value. Value 0 given back" << Endl;
00946    return 0;
00947 }

Generated on Tue Jul 5 15:25:37 2011 for ROOT_528-00b_version by  doxygen 1.5.1