00001 #include "TString.h"
00002 #include "TDirectory.h"
00003 #include "TH1F.h"
00004 #include "TFile.h"
00005 #include "TCanvas.h"
00006 #include "TLegend.h"
00007 #include "TROOT.h"
00008 #include "TKey.h"
00009 #include "TH2F.h"
00010 #include "TPad.h"
00011 #include "TObjArray.h"
00012 #include "TText.h"
00013
00014 #include "network.C"
00015
00016 void DrawNetworkMovie( TFile* file, const TString& methodType, const TString& methodTitle )
00017 {
00018
00019 TString dirname = methodType + "/" + methodTitle + "/" + "EpochMonitoring";
00020 TDirectory *epochDir = (TDirectory*)file->Get( dirname );
00021 if (!epochDir) {
00022 cout << "Big troubles: could not find directory \"" << dirname << "\"" << endl;
00023 exit(1);
00024 }
00025 epochDir->cd();
00026
00027
00028 TIter keyIt(epochDir->GetListOfKeys());
00029 TKey *key;
00030 vector<TString> epochList;
00031 Int_t ic = 0;
00032 while ((key = (TKey*)keyIt())) {
00033
00034 if (!gROOT->GetClass(key->GetClassName())->InheritsFrom("TH2F")) continue;
00035 TString name = key->GetName();
00036
00037 if (!name.BeginsWith("epochmonitoring___")) continue;
00038
00039
00040 TObjArray* tokens = name.Tokenize("_");
00041 TString es = ((TObjString*)tokens->At(2))->GetString();
00042
00043
00044 Bool_t isOld = kFALSE;
00045 for (vector<TString>::const_iterator it = epochList.begin(); it < epochList.end(); it++) {
00046 if (*it == es) isOld = kTRUE;
00047 }
00048 if (isOld) continue;
00049 epochList.push_back( es );
00050
00051
00052 TString bulkname = Form( "epochmonitoring___epoch_%s_weights_hist", es.Data() );
00053
00054
00055 if (ic <= 60) draw_network( file, epochDir, bulkname, kTRUE, es );
00056 ic++;
00057 }
00058 }
00059
00060
00061 void DrawMLPoutputMovie( TFile* file, const TString& methodType, const TString& methodTitle )
00062 {
00063 gROOT->SetBatch( 1 );
00064
00065
00066 const Int_t width = 600;
00067
00068
00069 TCanvas* c = 0;
00070
00071 Float_t nrms = 4;
00072 Float_t xmin = -1.2;
00073 Float_t xmax = 1.2;
00074 Float_t ymin = 0;
00075 Float_t ymax = 0;
00076 Float_t maxMult = 6.0;
00077 Int_t countCanvas = 0;
00078 Bool_t first = kTRUE;
00079
00080 TString dirname = methodType + "/" + methodTitle + "/" + "EpochMonitoring";
00081 TDirectory *epochDir = (TDirectory*)file->Get( dirname );
00082 if (!epochDir) {
00083 cout << "Big troubles: could not find directory \"" << dirname << "\"" << endl;
00084 exit(1);
00085 }
00086
00087
00088 TIter keyItTit(epochDir->GetListOfKeys());
00089 TKey *titkeyTit;
00090 while ((titkeyTit = (TKey*)keyItTit())) {
00091
00092 if (!gROOT->GetClass(titkeyTit->GetClassName())->InheritsFrom("TH1F")) continue;
00093 TString name = titkeyTit->GetName();
00094
00095 if (!name.BeginsWith("convergencetest___")) continue;
00096 if (!name.Contains("_train_")) continue;
00097 if (name.EndsWith( "_B")) continue;
00098
00099
00100 if (!name.EndsWith( "_S")) {
00101 cout << "Big troubles with histogram: " << name << " -> should end with _S" << endl;
00102 exit(1);
00103 }
00104
00105
00106 countCanvas++;
00107 TString ctitle = Form("TMVA response %s",methodTitle.Data());
00108 c = new TCanvas( Form("canvas%d", countCanvas), ctitle, 0, 0, width, (Int_t)width*0.78 );
00109
00110 TH1F* sig = (TH1F*)titkeyTit->ReadObj();
00111 sig->SetTitle( Form("TMVA response for classifier: %s", methodTitle.Data()) );
00112
00113 TString dataType = (name.Contains("_train_") ? "(training sample)" : "(test sample)");
00114
00115
00116 TString nbn = sig->GetName(); nbn[nbn.Length()-1] = 'B';
00117 TH1F* bgd = dynamic_cast<TH1F*>(epochDir->Get( nbn ));
00118 if (bgd == 0) {
00119 cout << "Big troubles with histogram: " << bgd << " -> cannot find!" << endl;
00120 exit(1);
00121 }
00122
00123 cout << "sig = " << sig->GetName() << endl;
00124 cout << "bgd = " << bgd->GetName() << endl;
00125
00126
00127 TMVAGlob::SetSignalAndBackgroundStyle( sig, bgd );
00128
00129
00130 TMVAGlob::NormalizeHists( sig, bgd );
00131
00132
00133 if (first) {
00134 if (xmin == 0 && xmax == 0) {
00135 xmin = TMath::Max( TMath::Min(sig->GetMean() - nrms*sig->GetRMS(),
00136 bgd->GetMean() - nrms*bgd->GetRMS() ),
00137 sig->GetXaxis()->GetXmin() );
00138 xmax = TMath::Min( TMath::Max(sig->GetMean() + nrms*sig->GetRMS(),
00139 bgd->GetMean() + nrms*bgd->GetRMS() ),
00140 sig->GetXaxis()->GetXmax() );
00141 }
00142 ymin = 0;
00143 ymax = TMath::Max( sig->GetMaximum(), bgd->GetMaximum() )*maxMult;
00144 first = kFALSE;
00145 }
00146
00147
00148 Int_t nb = 100;
00149 TString hFrameName(TString("frame") + methodTitle);
00150 TObject *o = gROOT->FindObject(hFrameName);
00151 if(o) delete o;
00152 TH2F* frame = new TH2F( hFrameName, sig->GetTitle(),
00153 nb, xmin, xmax, nb, ymin, ymax );
00154 frame->GetXaxis()->SetTitle( methodTitle + " response" );
00155 frame->GetYaxis()->SetTitle("(1/N) dN^{ }/^{ }dx");
00156 TMVAGlob::SetFrameStyle( frame );
00157
00158
00159 TObjArray* tokens = name.Tokenize("_");
00160 TString es = ((TObjString*)tokens->At(4))->GetString();
00161 if (!es.IsFloat()) {
00162 cout << "Big troubles in epoch parsing: \"" << es << "\" is not float" << endl;
00163 exit(1);
00164 }
00165 Int_t epoch = es.Atoi();
00166
00167
00168 frame->Draw();
00169
00170 c->GetPad(0)->SetLeftMargin( 0.105 );
00171 frame->GetYaxis()->SetTitleOffset( 1.2 );
00172
00173
00174 TLegend *legend= new TLegend( c->GetLeftMargin(), 1 - c->GetTopMargin() - 0.12,
00175 c->GetLeftMargin() + 0.5, 1 - c->GetTopMargin() );
00176 legend->SetFillStyle( 1 );
00177 legend->AddEntry(sig,TString("Signal ") + dataType, "F");
00178 legend->AddEntry(bgd,TString("Background ") + dataType, "F");
00179 legend->SetBorderSize(1);
00180 legend->SetMargin( 0.15 );
00181 legend->Draw("same");
00182
00183 TText* t = new TText();
00184 t->SetTextSize( 0.04 );
00185 t->SetTextColor( 1 );
00186 t->SetTextAlign( 31 );
00187 t->DrawTextNDC( 1 - c->GetRightMargin(), 1 - c->GetTopMargin() + 0.015, Form( "Epoch: %i", epoch) );
00188
00189
00190 sig->Draw("samehist");
00191 bgd->Draw("samehist");
00192
00193
00194 TString dirname = "movieplots";
00195 TString foutname = dirname + "/" + name;
00196 foutname.Resize( foutname.Length()-2 );
00197 foutname.ReplaceAll("convergencetest___","");
00198 foutname += ".gif";
00199
00200 cout << "storing file: " << foutname << endl;
00201
00202 c->Update();
00203 c->Print(foutname);
00204 }
00205 }
00206
00207
00208
00209 void MovieMaker( TString methodType = "Method_MLP", TString methodTitle = "MLP" )
00210 {
00211 TString fname = "TMVA.root";
00212 TFile* file = TMVAGlob::OpenFile( fname );
00213
00214
00215 DrawNetworkMovie( file, methodType, methodTitle );
00216 }
00217