#include <vector>
#include <limits>

#include "TH1F.h"

#include "TMVA/ResultsMulticlass.h"
#include "TMVA/MsgLogger.h"
#include "TMVA/DataSet.h"
#include "TMVA/DataSetInfo.h"
#include "TMVA/Event.h"
#include "TMVA/Tools.h"
#include "TMVA/GeneticAlgorithm.h"
#include "TMVA/GeneticFitter.h"
#include "TMVA/Interval.h"

TMVA::ResultsMulticlass::ResultsMulticlass( const DataSetInfo* dsi )
   : Results( dsi ),
     IFitterTarget(),
     fLogger( new MsgLogger("ResultsMulticlass", kINFO) ),
     fClassToOptimize(0),
     fAchievableEff(dsi->GetNClasses()),
     fAchievablePur(dsi->GetNClasses()),
     fBestCuts(dsi->GetNClasses(), std::vector<Double_t>(dsi->GetNClasses()))
{
   // constructor
}

TMVA::ResultsMulticlass::~ResultsMulticlass()
{
   // destructor
   delete fLogger;
}

void TMVA::ResultsMulticlass::SetValue( std::vector<Float_t>& value, Int_t ievt )
{
   // store the vector of multiclass response values for event ievt,
   // growing the container if necessary
   if (ievt >= (Int_t)fMultiClassValues.size()) fMultiClassValues.resize( ievt+1 );
   fMultiClassValues[ievt] = value;
}

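// Estimator used by the genetic fitter in GetBestMultiClassCuts: for the class
// currently being optimized (fClassToOptimize) it applies one cut per classifier
// output, computes the efficiency (selected weight of the target class over its
// total weight) and the purity (selected weight of the target class over all
// selected weight), and returns 1/(efficiency*purity) so that minimizing the
// estimator maximizes the product.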
Double_t TMVA::ResultsMulticlass::EstimatorFunction( std::vector<Double_t> & cutvalues )
{
   DataSet* ds = GetDataSet();
   ds->SetCurrentType( GetTreeType() );

   Float_t truePositive  = 0;
   Float_t falsePositive = 0;
   Float_t sumWeights    = 0;

   for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
      Event* ev = ds->GetEvent(ievt);
      Float_t w = ev->GetWeight();
      if (ev->GetClass()==fClassToOptimize)
         sumWeights += w;

      // apply the cuts on all classifier outputs: a positive cut value requires
      // the response to exceed the cut, a negative cut value requires the
      // response to stay at or below its absolute value
      bool passed = true;
      for (UInt_t icls = 0; icls<cutvalues.size(); ++icls) {
         if (cutvalues.at(icls)<0. ? -fMultiClassValues[ievt][icls]<cutvalues.at(icls)
                                   : fMultiClassValues[ievt][icls]<=cutvalues.at(icls)) {
            passed = false;
            break;
         }
      }
      if (!passed)
         continue;

      if (ev->GetClass()==fClassToOptimize)
         truePositive  += w;
      else
         falsePositive += w;
   }

   Float_t eff = truePositive/sumWeights;
   Float_t pur = truePositive/(truePositive+falsePositive);
   Float_t effTimesPur = eff*pur;

   // the genetic fitter minimizes the estimator, hence return the inverse
   Float_t toMinimize = std::numeric_limits<float>::max();
   if (effTimesPur > 0)
      toMinimize = 1./(effTimesPur);

   fAchievableEff.at(fClassToOptimize) = eff;
   fAchievablePur.at(fClassToOptimize) = pur;

   return toMinimize;
}

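// Determine, with a genetic algorithm, the set of cuts on the multiclass
// outputs that maximizes efficiency times purity for the given target class.
// One cut value per class output is optimized within [-1,1]; the result is
// cached in fBestCuts and returned. The caller is expected to invoke this
// once per target class after the multiclass response has been filled.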
std::vector<Double_t> TMVA::ResultsMulticlass::GetBestMultiClassCuts(UInt_t targetClass)
{
   const DataSetInfo* dsi = GetDataSetInfo();
   Log() << kINFO << "Calculating best set of cuts for class "
         << dsi->GetClassInfo( targetClass )->GetName() << Endl;

   fClassToOptimize = targetClass;

   // one independent interval per class output, each cut allowed to float in [-1,1];
   // the intervals are owned here and deleted after the fit
   std::vector<Interval*> ranges;
   for (UInt_t icls = 0; icls < dsi->GetNClasses(); ++icls) {
      ranges.push_back( new Interval(-1,1) );
   }

   const TString name( "MulticlassGA" );
   const TString opts( "PopSize=100:Steps=30" );
   GeneticFitter mg( *this, name, ranges, opts );

   std::vector<Double_t> result;
   mg.Run(result);

   fBestCuts.at(targetClass) = result;

   UInt_t n = 0;
   for (std::vector<Double_t>::iterator it = result.begin(); it != result.end(); ++it) {
      Log() << kINFO << "  cutValue[" << dsi->GetClassInfo( n )->GetName() << "] = " << (*it) << ";" << Endl;
      n++;
   }

   for (UInt_t icls = 0; icls < ranges.size(); ++icls) delete ranges.at(icls);

   return result;
}

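// Book, fill, normalize and store one response histogram per combination of
// true event class and classifier output, named
// "<prefix>_<output class>_prob_for_<true class>".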
void TMVA::ResultsMulticlass::CreateMulticlassHistos( TString prefix, Int_t nbins, Int_t )
{
   Log() << kINFO << "Creating multiclass response histograms..." << Endl;

   DataSet* ds = GetDataSet();
   ds->SetCurrentType( GetTreeType() );
   const DataSetInfo* dsi = GetDataSetInfo();

   // book one histogram per (true class, classifier output) combination
   std::vector<std::vector<TH1F*> > histos;
   Float_t xmin = 0.-0.0002;
   Float_t xmax = 1.+0.0002;
   for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
      histos.push_back(std::vector<TH1F*>(0));
      for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
         TString name(Form("%s_%s_prob_for_%s",prefix.Data(),
                           dsi->GetClassInfo( jCls )->GetName().Data(),
                           dsi->GetClassInfo( iCls )->GetName().Data()));
         histos.at(iCls).push_back(new TH1F(name,name,nbins,xmin,xmax));
      }
   }

   // fill each classifier output with the event weight, sorted by true class
   for (Int_t ievt=0; ievt<ds->GetNEvents(); ievt++) {
      Event* ev = ds->GetEvent(ievt);
      Int_t cls = ev->GetClass();
      Float_t w = ev->GetWeight();
      for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
         histos.at(cls).at(jCls)->Fill(fMultiClassValues[ievt][jCls],w);
      }
   }

   // normalize the histograms and store them in the results container
   for (UInt_t iCls = 0; iCls < dsi->GetNClasses(); iCls++) {
      for (UInt_t jCls = 0; jCls < dsi->GetNClasses(); jCls++) {
         gTools().NormHist( histos.at(iCls).at(jCls) );
         Store(histos.at(iCls).at(jCls));
      }
   }
}