00001 #include "TLegend.h"
00002 #include "TText.h"
00003 #include "TH2.h"
00004
00005 #include "tmvaglob.C"
00006
00007
00008
00009
00010
// Selects which kind of plot mvasMulticlass() produces:
//   MVAType     - response distributions of the test sample only
//   CompareType - test-sample responses overlaid with the training-sample
//                 responses, plus Kolmogorov-Smirnov consistency tests
enum HistType {
   MVAType     = 0,
   CompareType = 1
};
00012
00013
00014
00015 void mvasMulticlass( TString fin = "TMVAMulticlass.root", HistType htype = MVAType, Bool_t useTMVAStyle = kTRUE )
00016 {
00017
00018 TMVAGlob::Initialize( useTMVAStyle );
00019
00020
00021 const Bool_t Save_Images = kTRUE;
00022
00023
00024 TFile* file = TMVAGlob::OpenFile( fin );
00025
00026 TDirectory* tempdir = (TDirectory*)file->Get("InputVariables_Id" );
00027 std::vector<TString> classnames(TMVAGlob::GetClassNames(tempdir));
00028
00029
00030 Int_t xPad = 1;
00031 Int_t yPad = 1;
00032 Int_t noPad = xPad * yPad ;
00033 const Int_t width = 600;
00034
00035
00036 TCanvas *c = 0;
00037
00038
00039 Int_t countCanvas = 0;
00040
00041
00042 TIter next(file->GetListOfKeys());
00043 TKey *key(0);
00044 while ((key = (TKey*)next())) {
00045
00046 if (!TString(key->GetName()).BeginsWith("Method_")) continue;
00047 if (!gROOT->GetClass(key->GetClassName())->InheritsFrom("TDirectory")) continue;
00048
00049 TString methodName;
00050 TMVAGlob::GetMethodName(methodName,key);
00051
00052 TDirectory* mDir = (TDirectory*)key->ReadObj();
00053
00054 TIter keyIt(mDir->GetListOfKeys());
00055 TKey *titkey;
00056 while ((titkey = (TKey*)keyIt())) {
00057
00058 if (!gROOT->GetClass(titkey->GetClassName())->InheritsFrom("TDirectory")) continue;
00059
00060 TDirectory *titDir = (TDirectory *)titkey->ReadObj();
00061 TString methodTitle;
00062 TMVAGlob::GetMethodTitle(methodTitle,titDir);
00063
00064 cout << "--- Found directory for method: " << methodName << "::" << methodTitle << endl;
00065 TString hname = "MVA_" + methodTitle;
00066 for(Int_t icls = 0; icls < classnames.size(); ++icls){
00067 TObjArray hists;
00068
00069 std::vector<TString>::iterator classiter = classnames.begin();
00070 for(; classiter!=classnames.end(); ++classiter){
00071 TString name(hname+"_Test_"+ classnames.at(icls)
00072 + "_prob_for_" + *classiter);
00073 TH1 *hist = (TH1*)titDir->Get(name);
00074 if (hist==0){
00075 cout << ":\t mva distribution not available (this is normal for Cut classifier)" << endl;
00076 continue;
00077 }
00078 hists.Add(hist);
00079 }
00080
00081
00082
00083 ((TH1*)hists.First())->SetTitle( Form("TMVA response for classifier: %s", methodTitle.Data() ));
00084
00085
00086
00087 TString ctitle = ((htype == MVAType) ?
00088 Form("TMVA response for class %s %s", classnames.at(icls).Data(),methodTitle.Data()) :
00089 Form("TMVA comparison for class %s %s", classnames.at(icls).Data(),methodTitle.Data())) ;
00090
00091 c = new TCanvas( Form("canvas%d", countCanvas+1), ctitle,
00092 countCanvas*50+200, countCanvas*20, width, (Int_t)width*0.78 );
00093
00094
00095
00096 TMVAGlob::SetMultiClassStyle( &hists );
00097
00098
00099 Float_t histmax = -1;
00100 for(Int_t i=0; i<hists.GetEntriesFast(); ++i){
00101 TMVAGlob::NormalizeHist((TH1*)hists[i] );
00102 if(((TH1*)hists[i])->GetMaximum() > histmax)
00103 histmax = ((TH1*)hists[i])->GetMaximum();
00104 }
00105
00106
00107 Float_t xmin = 0;
00108 Float_t xmax = 1;
00109 Float_t ymin = 0;
00110 Float_t maxMult = (htype == CompareType) ? 1.3 : 1.2;
00111 Float_t ymax = histmax*maxMult;
00112
00113 Int_t nb = 500;
00114 TString hFrameName(TString("frame") + methodTitle);
00115 TObject *o = gROOT->FindObject(hFrameName);
00116 if(o) delete o;
00117 TH2F* frame = new TH2F( hFrameName, ((TH1*)hists.First())->GetTitle(),
00118 nb, xmin, xmax, nb, ymin, ymax );
00119 frame->GetXaxis()->SetTitle( methodTitle + " response for "+classnames.at(icls));
00120 frame->GetYaxis()->SetTitle("(1/N) dN^{ }/^{ }dx");
00121 TMVAGlob::SetFrameStyle( frame );
00122
00123
00124 frame->Draw();
00125
00126 c->GetPad(0)->SetLeftMargin( 0.105 );
00127 frame->GetYaxis()->SetTitleOffset( 1.2 );
00128
00129
00130 TLegend *legend= new TLegend( c->GetLeftMargin(), 1 - c->GetTopMargin() - 0.12,
00131 c->GetLeftMargin() + (htype == CompareType ? 0.40 : 0.3), 1 - c->GetTopMargin() );
00132 legend->SetFillStyle( 1 );
00133 classiter = classnames.begin();
00134
00135 for(Int_t i=0; i<hists.GetEntriesFast(); ++i, ++classiter){
00136 legend->AddEntry(((TH1*)hists[i]),*classiter,"F");
00137 }
00138
00139 legend->SetBorderSize(1);
00140 legend->SetMargin( 0.3 );
00141 legend->Draw("same");
00142
00143
00144 for(Int_t i=0; i<hists.GetEntriesFast(); ++i){
00145
00146 ((TH1*)hists[i])->Draw("histsame");
00147 TString ytit = TString("(1/N) ") + ((TH1*)hists[i])->GetYaxis()->GetTitle();
00148 ((TH1*)hists[i])->GetYaxis()->SetTitle( ytit );
00149
00150 }
00151
00152
00153 if (htype == CompareType) {
00154
00155 TObjArray othists;
00156
00157 classiter = classnames.begin();
00158 for(; classiter!=classnames.end(); ++classiter){
00159 TString name(hname+"_Train_"+ classnames.at(icls)
00160 + "_prob_for_" + *classiter);
00161 TH1 *hist = (TH1*)titDir->Get(name);
00162 if (hist==0){
00163 cout << ":\t comparison histogram for overtraining check not available!" << endl;
00164 continue;
00165 }
00166 othists.Add(hist);
00167 }
00168
00169 TLegend *legend2= new TLegend( 1 - c->GetRightMargin() - 0.42, 1 - c->GetTopMargin() - 0.12,
00170 1 - c->GetRightMargin(), 1 - c->GetTopMargin() );
00171 legend2->SetFillStyle( 1 );
00172 legend2->SetBorderSize(1);
00173
00174 classiter = classnames.begin();
00175 for(Int_t i=0; i<othists.GetEntriesFast(); ++i, ++classiter){
00176 legend2->AddEntry(((TH1*)othists[i]),*classiter+" (training sample)","P");
00177 }
00178 legend2->SetMargin( 0.1 );
00179 legend2->Draw("same");
00180
00181
00182 for(Int_t i=0; i<othists.GetEntriesFast(); ++i){
00183 TMVAGlob::NormalizeHist((TH1*)othists[i] );
00184 if(((TH1*)othists[i])->GetMaximum() > histmax)
00185 histmax = ((TH1*)othists[i])->GetMaximum();
00186 }
00187
00188 TMVAGlob::SetMultiClassStyle( &othists );
00189 for(Int_t i=0; i<othists.GetEntriesFast(); ++i){
00190 Int_t col = ((TH1*)hists[i])->GetLineColor();
00191 ((TH1*)othists[i])->SetMarkerSize( 0.7 );
00192 ((TH1*)othists[i])->SetMarkerStyle( 20 );
00193 ((TH1*)othists[i])->SetMarkerColor( col );
00194 ((TH1*)othists[i])->SetLineWidth( 1 );
00195 ((TH1*)othists[i])->Draw("e1same");
00196 }
00197
00198 ymax = histmax*maxMult;
00199 frame->GetYaxis()->SetLimits( 0, ymax );
00200
00201
00202 TMVAGlob::SetMultiClassStyle( &othists );
00203 for(Int_t i=0; i<hists.GetEntriesFast(); ++i){
00204 ((TH1*)hists[i])->SetLineWidth( 1 );
00205 }
00206
00207
00208
00209
00210 cout << "--- Perform Kolmogorov-Smirnov tests" << endl;
00211 cout << "--- Goodness of consistency for class " << classnames.at(icls)<< endl;
00212
00213 for(Int_t j=0; j<othists.GetEntriesFast(); ++j){
00214 Float_t kol = ((TH1*)hists[j])->KolmogorovTest(((TH1*)othists[j]));
00215 cout << classnames.at(j) << ": " << kol << endl;
00216
00217 }
00218
00219
00220
00221
00222
00223
00224 }
00225
00226
00227
00228 frame->Draw("sameaxis");
00229
00230
00231
00232
00233
00234
00235
00236
00237
00238
00239
00240
00241
00242
00243
00244 c->Update();
00245
00246
00247
00248 TMVAGlob::plot_logo(1.058);
00249 if (Save_Images) {
00250 if (htype == MVAType) TMVAGlob::imgconv( c, Form("plots/mva_%s_%s",classnames.at(icls).Data(), methodTitle.Data()) );
00251 else if (htype == CompareType) TMVAGlob::imgconv( c, Form("plots/overtrain_%s_%s",classnames.at(icls).Data(), methodTitle.Data()) );
00252
00253 }
00254 countCanvas++;
00255 }
00256 }
00257 cout << "";
00258 }
00259 }
00260