VariableTransformBase.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: VariableTransformBase.cxx 33993 2010-06-19 11:25:14Z stelzer $
00002 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : VariableTransformBase                                                 *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Implementation (see header for description)                               *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00015  *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
00016  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00017  *                                                                                *
00018  * Copyright (c) 2005:                                                            *
00019  *      CERN, Switzerland                                                         *
00020  *      MPI-K Heidelberg, Germany                                                 *
00021  *                                                                                *
00022  * Redistribution and use in source and binary forms, with or without             *
00023  * modification, are permitted according to the terms listed in LICENSE           *
00024  * (http://tmva.sourceforge.net/LICENSE)                                          *
00025  **********************************************************************************/
00026 
00027 #include <iomanip>
00028 
00029 #include "TMath.h"
00030 #include "TVectorD.h"
00031 #include "TH1.h"
00032 #include "TH2.h"
00033 #include "TProfile.h"
00034 
00035 #include "TMVA/VariableTransformBase.h"
00036 #include "TMVA/Ranking.h"
00037 #include "TMVA/Config.h"
00038 #include "TMVA/Tools.h"
00039 #include "TMVA/Version.h"
00040 
00041 #ifndef ROOT_TMVA_MsgLogger
00042 #include "TMVA/MsgLogger.h"
00043 #endif
00044 
00045 ClassImp(TMVA::VariableTransformBase)
00046 
00047 //_______________________________________________________________________
00048 TMVA::VariableTransformBase::VariableTransformBase( DataSetInfo& dsi,
00049                                                     Types::EVariableTransform tf,
00050                                                     const TString& trfName )
00051    : TObject(),
00052      fDsi(dsi),
00053      fTransformedEvent(0),
00054      fBackTransformedEvent(0),
00055      fVariableTransform(tf),
00056      fEnabled( kTRUE ),
00057      fCreated( kFALSE ),
00058      fNormalise( kFALSE ),
00059      fTransformName(trfName),
00060      fTMVAVersion(TMVA_VERSION_CODE),
00061      fLogger( 0 )
00062 {
00063    // standard constructor
00064    fLogger = new MsgLogger(this, kINFO);
00065    for (UInt_t ivar = 0; ivar < fDsi.GetNVariables(); ivar++) {
00066       fVariables.push_back( VariableInfo( fDsi.GetVariableInfo(ivar) ) );
00067    }
00068    for (UInt_t itgt = 0; itgt < fDsi.GetNTargets(); itgt++) {
00069       fTargets.push_back( VariableInfo( fDsi.GetTargetInfo(itgt) ) );
00070    }
00071 }
00072 
00073 //_______________________________________________________________________
00074 TMVA::VariableTransformBase::~VariableTransformBase()
00075 {
00076    if (fTransformedEvent!=0)     delete fTransformedEvent;
00077    if (fBackTransformedEvent!=0) delete fBackTransformedEvent;
00078    // destructor
00079    delete fLogger;
00080 }
00081 
00082 //_______________________________________________________________________
00083 void TMVA::VariableTransformBase::CalcNorm( const std::vector<Event*>& events ) 
00084 {
00085    // method to calculate minimum, maximum, mean, and RMS for all
00086    // variables used in the MVA
00087 
00088    if (!IsCreated()) return;
00089 
00090    const UInt_t nvars = GetNVariables();
00091    const UInt_t ntgts = GetNTargets();
00092 
00093    UInt_t nevts = events.size();
00094 
00095    TVectorD x2( nvars+ntgts ); x2 *= 0;
00096    TVectorD x0( nvars+ntgts ); x0 *= 0;   
00097 
00098    Double_t sumOfWeights = 0;
00099    for (UInt_t ievt=0; ievt<nevts; ievt++) {
00100       const Event* ev = events[ievt];
00101 
00102       Double_t weight = ev->GetWeight();
00103       sumOfWeights += weight;
00104       for (UInt_t ivar=0; ivar<nvars; ivar++) {
00105          Double_t x = ev->GetValue(ivar);
00106          if (ievt==0) {
00107             Variables().at(ivar).SetMin(x);
00108             Variables().at(ivar).SetMax(x);
00109          } 
00110          else {
00111             UpdateNorm( ivar,  x );
00112          }
00113          x0(ivar) += x*weight;
00114          x2(ivar) += x*x*weight;
00115       }
00116       for (UInt_t itgt=0; itgt<ntgts; itgt++) {
00117          Double_t x = ev->GetTarget(itgt);
00118          if (ievt==0) {
00119             Targets().at(itgt).SetMin(x);
00120             Targets().at(itgt).SetMax(x);
00121          } 
00122          else {
00123             UpdateNorm( nvars+itgt,  x );
00124          }
00125          x0(nvars+itgt) += x*weight;
00126          x2(nvars+itgt) += x*x*weight;
00127       }
00128    }
00129 
00130    if (sumOfWeights <= 0) {
00131       Log() << kFATAL << " the sum of event weights calcualted for your input is == 0"
00132             << " or exactly: " << sumOfWeights << " there is obviously some problem..."<< Endl;
00133    } 
00134 
00135    // set Mean and RMS
00136    for (UInt_t ivar=0; ivar<nvars; ivar++) {
00137       Double_t mean = x0(ivar)/sumOfWeights;
00138       
00139       Variables().at(ivar).SetMean( mean ); 
00140       if (x2(ivar)/sumOfWeights - mean*mean < 0) {
00141          Log() << kFATAL << " the RMS of your input variable " << ivar 
00142                << " evaluates to an imaginary number: sqrt("<< x2(ivar)/sumOfWeights - mean*mean
00143                <<") .. sometimes related to a problem with outliers and negative event weights"
00144                << Endl;
00145       }
00146       Variables().at(ivar).SetRMS( TMath::Sqrt( x2(ivar)/sumOfWeights - mean*mean) );
00147    }
00148    for (UInt_t itgt=0; itgt<ntgts; itgt++) {
00149       Double_t mean = x0(nvars+itgt)/sumOfWeights;
00150       Targets().at(itgt).SetMean( mean ); 
00151       if (x2(nvars+itgt)/sumOfWeights - mean*mean < 0) {
00152          Log() << kFATAL << " the RMS of your target variable " << itgt 
00153                << " evaluates to an imaginary number: sqrt(" << x2(nvars+itgt)/sumOfWeights - mean*mean
00154                <<") .. sometimes related to a problem with outliers and negative event weights"
00155                << Endl;
00156       }
00157       Targets().at(itgt).SetRMS( TMath::Sqrt( x2(nvars+itgt)/sumOfWeights - mean*mean) );
00158    }
00159 
00160    Log() << kVERBOSE << "Set minNorm/maxNorm for variables to: " << Endl;
00161    Log() << std::setprecision(3);
00162    for (UInt_t ivar=0; ivar<GetNVariables(); ivar++)
00163       Log() << "    " << Variables().at(ivar).GetInternalName()
00164               << "\t: [" << Variables().at(ivar).GetMin() << "\t, " << Variables().at(ivar).GetMax() << "\t] " << Endl;
00165    Log() << kVERBOSE << "Set minNorm/maxNorm for targets to: " << Endl;
00166    Log() << std::setprecision(3);
00167    for (UInt_t itgt=0; itgt<GetNTargets(); itgt++)
00168       Log() << "    " << Targets().at(itgt).GetInternalName()
00169               << "\t: [" << Targets().at(itgt).GetMin() << "\t, " << Targets().at(itgt).GetMax() << "\t] " << Endl;
00170    Log() << std::setprecision(5); // reset to better value       
00171 }
00172 
00173 //_______________________________________________________________________
00174 std::vector<TString>* TMVA::VariableTransformBase::GetTransformationStrings( Int_t /*cls*/ ) const
00175 {
00176    // default transformation output
00177    // --> only indicate that transformation occurred
00178    std::vector<TString>* strVec = new std::vector<TString>;
00179    for (UInt_t ivar=0; ivar<GetNVariables(); ivar++) {
00180       strVec->push_back( Variables()[ivar].GetLabel() + "_[transformed]");
00181    }
00182 
00183    return strVec;   
00184 }
00185 
00186 //_______________________________________________________________________
00187 void TMVA::VariableTransformBase::UpdateNorm ( Int_t ivar,  Double_t x ) 
00188 {
00189    // update min and max of a given variable (target) and a given transformation method
00190    Int_t nvars = fDsi.GetNVariables();
00191    if( ivar < nvars ){
00192       if (x < Variables().at(ivar).GetMin()) Variables().at(ivar).SetMin(x);
00193       if (x > Variables().at(ivar).GetMax()) Variables().at(ivar).SetMax(x);
00194    }else{
00195       if (x < Targets().at(ivar-nvars).GetMin()) Targets().at(ivar-nvars).SetMin(x);
00196       if (x > Targets().at(ivar-nvars).GetMax()) Targets().at(ivar-nvars).SetMax(x);
00197    }
00198 }
00199 

Generated on Tue Jul 5 15:25:38 2011 for ROOT_528-00b_version by  doxygen 1.5.1