ResultsMulticlass.cxx

Go to the documentation of this file.
00001 // @(#)root/tmva $Id: ResultsMulticlass.cxx 37138 2010-12-01 10:01:16Z evt $
00002 // Author: Andreas Hoecker, Peter Speckmayer, Joerg Stelzer, Helge Voss
00003 
00004 /**********************************************************************************
00005  * Project: TMVA - a Root-integrated toolkit for multivariate data analysis       *
00006  * Package: TMVA                                                                  *
00007  * Class  : ResultsMulticlass                                                     *
00008  * Web    : http://tmva.sourceforge.net                                           *
00009  *                                                                                *
00010  * Description:                                                                   *
00011  *      Implementation (see header for description)                               *
00012  *                                                                                *
00013  * Authors (alphabetical):                                                        *
00014  *      Andreas Hoecker <Andreas.Hocker@cern.ch> - CERN, Switzerland              *
00015  *      Peter Speckmayer <Peter.Speckmayer@cern.ch>  - CERN, Switzerland          *
00016  *      Joerg Stelzer   <Joerg.Stelzer@cern.ch>  - CERN, Switzerland              *
00017  *      Helge Voss      <Helge.Voss@cern.ch>     - MPI-K Heidelberg, Germany      *
00018  *                                                                                *
00019  * Copyright (c) 2006:                                                            *
00020  *      CERN, Switzerland                                                         *
00021  *      MPI-K Heidelberg, Germany                                                 *
00022  *                                                                                *
00023  * Redistribution and use in source and binary forms, with or without             *
00024  * modification, are permitted according to the terms listed in LICENSE           *
00025  * (http://tmva.sourceforge.net/LICENSE)                                          *
00026  **********************************************************************************/
00027 
#include <limits>
#include <vector>

#include "TMVA/ResultsMulticlass.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/DataSet.h"
#include "TMVA/Tools.h"
#include "TMVA/GeneticAlgorithm.h"
#include "TMVA/GeneticFitter.h"
00036 
00037 //_______________________________________________________________________
00038 TMVA::ResultsMulticlass::ResultsMulticlass( const DataSetInfo* dsi ) 
00039    : Results( dsi ),
00040      IFitterTarget(),
00041      fLogger( new MsgLogger("ResultsMulticlass", kINFO) ),
00042      fClassToOptimize(0),
00043      fAchievableEff(dsi->GetNClasses()),
00044      fAchievablePur(dsi->GetNClasses()),
00045      fBestCuts(dsi->GetNClasses(),std::vector<Double_t>(dsi->GetNClasses()))
00046 {
00047    // constructor
00048 }
00049 
00050 //_______________________________________________________________________
00051 TMVA::ResultsMulticlass::~ResultsMulticlass() 
00052 {
00053    // destructor
00054    delete fLogger;
00055 }
00056 
00057 //_______________________________________________________________________
00058 void TMVA::ResultsMulticlass::SetValue( std::vector<Float_t>& value, Int_t ievt )
00059 {
00060    if (ievt >= (Int_t)fMultiClassValues.size()) fMultiClassValues.resize( ievt+1 );
00061    fMultiClassValues[ievt] = value; 
00062 }
00063  
00064 //_______________________________________________________________________
00065  
00066 Double_t TMVA::ResultsMulticlass::EstimatorFunction( std::vector<Double_t> & cutvalues ){
00067    
00068    DataSet* ds = GetDataSet();
00069    ds->SetCurrentType( GetTreeType() );
00070    Float_t truePositive = 0;
00071    Float_t falsePositive = 0;
00072    Float_t sumWeights = 0;
00073  
00074    for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
00075       Event* ev = ds->GetEvent(ievt);
00076       Float_t w = ev->GetWeight();
00077       if(ev->GetClass()==fClassToOptimize)
00078          sumWeights += w;
00079       bool passed = true;
00080       for(UInt_t icls = 0; icls<cutvalues.size(); ++icls){
00081          if(cutvalues.at(icls)<0. ? -fMultiClassValues[ievt][icls]<cutvalues.at(icls) : fMultiClassValues[ievt][icls]<=cutvalues.at(icls)){
00082             passed = false;
00083             break;
00084          }
00085       }
00086       if(!passed)
00087          continue;
00088       if(ev->GetClass()==fClassToOptimize)
00089          truePositive += w;
00090       else
00091          falsePositive += w;
00092    }
00093    
00094    Float_t eff = truePositive/sumWeights;
00095    Float_t pur = truePositive/(truePositive+falsePositive);
00096    Float_t effTimesPur = eff*pur;
00097    
00098    Float_t toMinimize = std::numeric_limits<float>::max();
00099    if( effTimesPur > 0 )
00100       toMinimize = 1./(effTimesPur); // we want to minimize 1/efficiency*purity
00101 
00102    fAchievableEff.at(fClassToOptimize) = eff;
00103    fAchievablePur.at(fClassToOptimize) = pur;
00104 
00105    return toMinimize;
00106 }
00107 
00108 //_______________________________________________________________________
00109 
00110 std::vector<Double_t> TMVA::ResultsMulticlass::GetBestMultiClassCuts(UInt_t targetClass){
00111 
00112    //calculate the best working point (optimal cut values)
00113    //for the multiclass classifier
00114    const DataSetInfo* dsi = GetDataSetInfo();
00115    Log() << kINFO << "Calculating best set of cuts for class " 
00116          << dsi->GetClassInfo( targetClass )->GetName() << Endl;
00117   
00118    fClassToOptimize = targetClass;
00119    std::vector<Interval*> ranges(dsi->GetNClasses(), new Interval(-1,1));
00120    
00121    const TString name( "MulticlassGA" );
00122    const TString opts( "PopSize=100:Steps=30" );
00123    GeneticFitter mg( *this, name, ranges, opts);
00124    
00125    std::vector<Double_t> result;
00126    mg.Run(result);
00127 
00128    fBestCuts.at(targetClass) = result;
00129   
00130    UInt_t n = 0;
00131    for( std::vector<Double_t>::iterator it = result.begin(); it<result.end(); it++ ){
00132       Log() << kINFO << "  cutValue[" <<dsi->GetClassInfo( n )->GetName()  << "] = " << (*it) << ";"<< Endl;
00133       n++;
00134         }
00135    
00136    return result;
00137 }
00138 
00139 //_______________________________________________________________________
00140 
00141 void  TMVA::ResultsMulticlass::CreateMulticlassHistos( TString prefix, Int_t nbins, Int_t /* nbins_high */ )
00142 {
00143    //this function fills the mva response histos for multiclass classification
00144    Log() << kINFO << "Creating multiclass response histograms..." << Endl;
00145       
00146    DataSet* ds = GetDataSet();
00147    ds->SetCurrentType( GetTreeType() );
00148    const DataSetInfo* dsi = GetDataSetInfo();
00149    
00150    std::vector<std::vector<TH1F*> > histos;
00151    Float_t xmin = 0.-0.0002;
00152    Float_t xmax = 1.+0.0002;
00153    for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
00154       histos.push_back(std::vector<TH1F*>(0));
00155       for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
00156          TString name(Form("%s_%s_prob_for_%s",prefix.Data(),
00157                            dsi->GetClassInfo( jCls )->GetName().Data(),
00158                            dsi->GetClassInfo( iCls )->GetName().Data()));
00159          histos.at(iCls).push_back(new TH1F(name,name,nbins,xmin,xmax));
00160       }
00161    }
00162 
00163    for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
00164       Event* ev = ds->GetEvent(ievt);
00165       Int_t cls = ev->GetClass();
00166       Float_t w = ev->GetWeight();
00167       for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
00168          histos.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
00169       }
00170    }
00171    for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
00172       for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
00173          gTools().NormHist( histos.at(iCls).at(jCls) );
00174          Store(histos.at(iCls).at(jCls));
00175       }
00176    }
00177 
00178    /*
00179    //fill fine binned histos for testing
00180    if(prefix.Contains("Test")){
00181       std::vector<std::vector<TH1F*> > histos_highbin;
00182       for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
00183          histos_highbin.push_back(std::vector<TH1F*>(0));
00184          for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
00185             TString name(Form("%s_%s_prob_for_%s_HIGHBIN",prefix.Data(),
00186                               dsi->GetClassInfo( jCls )->GetName().Data(),
00187                               dsi->GetClassInfo( iCls )->GetName().Data()));
00188             histos_highbin.at(iCls).push_back(new TH1F(name,name,nbins_high,xmin,xmax));
00189          }
00190       }
00191       
00192       for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
00193          Event* ev = ds->GetEvent(ievt);
00194          Int_t cls = ev->GetClass();
00195          Float_t w = ev->GetWeight();
00196          for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
00197             histos_highbin.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
00198          }
00199       }
00200       for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
00201          for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
00202             gTools().NormHist( histos_highbin.at(iCls).at(jCls) );
00203             Store(histos_highbin.at(iCls).at(jCls));
00204          }
00205       }
00206    }
00207    */
00208 }

Generated on Tue Jul 5 15:25:33 2011 for ROOT_528-00b_version by  doxygen 1.5.1