mvas.C

Go to the documentation of this file.
00001 #include "TLegend.h"
00002 #include "TText.h"
00003 #include "TH2.h"
00004 
00005 #include "tmvaglob.C"
00006 
// This macro plots the resulting MVA distributions (signal and
// background overlaid) of the different MVA methods run in TMVA
// (e.g. by running TMVAnalysis.C).
00010 
// Selects which family of histograms mvas() looks up and plots.
enum HistType {
   MVAType     = 0, // classifier response, histograms "MVA_<title>_S" / "_B"
   ProbaType   = 1, // signal probability, histograms "MVA_<title>_Proba_S" / "_B"
   RarityType  = 2, // signal rarity, histograms "MVA_<title>_Rarity_S" / "_B"
   CompareType = 3  // overtraining check: test sample overlaid with "_Train" histograms
};
00012 
00013 // input: - Input file (result from TMVA)
00014 //        - use of TMVA plotting TStyle
00015 void mvas( TString fin = "TMVA.root", HistType htype = MVAType, Bool_t useTMVAStyle = kTRUE )
00016 {
00017    // set style and remove existing canvas'
00018    TMVAGlob::Initialize( useTMVAStyle );
00019 
00020    // switches
00021    const Bool_t Save_Images = kTRUE;
00022 
00023    // checks if file with name "fin" is already open, and if not opens one
00024    TFile* file = TMVAGlob::OpenFile( fin );  
00025 
00026    // define Canvas layout here!
00027    Int_t xPad = 1; // no of plots in x
00028    Int_t yPad = 1; // no of plots in y
00029    Int_t noPad = xPad * yPad ; 
00030    const Int_t width = 600;   // size of canvas
00031 
00032    // this defines how many canvases we need
00033    TCanvas *c = 0;
00034 
00035    // counter variables
00036    Int_t countCanvas = 0;
00037 
00038    // search for the right histograms in full list of keys
00039    TIter next(file->GetListOfKeys());
00040    TKey *key(0);   
00041    while ((key = (TKey*)next())) {
00042 
00043       if (!TString(key->GetName()).BeginsWith("Method_")) continue;
00044       if (!gROOT->GetClass(key->GetClassName())->InheritsFrom("TDirectory")) continue;
00045 
00046       TString methodName;
00047       TMVAGlob::GetMethodName(methodName,key);
00048 
00049       TDirectory* mDir = (TDirectory*)key->ReadObj();
00050 
00051       TIter keyIt(mDir->GetListOfKeys());
00052       TKey *titkey;
00053       while ((titkey = (TKey*)keyIt())) {
00054 
00055          if (!gROOT->GetClass(titkey->GetClassName())->InheritsFrom("TDirectory")) continue;
00056 
00057          TDirectory *titDir = (TDirectory *)titkey->ReadObj();
00058          TString methodTitle;
00059          TMVAGlob::GetMethodTitle(methodTitle,titDir);
00060 
00061          cout << "--- Found directory for method: " << methodName << "::" << methodTitle << flush;
00062          TString hname = "MVA_" + methodTitle;
00063          if      (htype == ProbaType  ) hname += "_Proba";
00064          else if (htype == RarityType ) hname += "_Rarity";
00065          TH1* sig = dynamic_cast<TH1*>(titDir->Get( hname + "_S" ));
00066          TH1* bgd = dynamic_cast<TH1*>(titDir->Get( hname + "_B" ));
00067 
00068          if (sig==0 || bgd==0) {
00069             if     (htype == MVAType)     
00070                cout << ":\t mva distribution not available (this is normal for Cut classifier)" << endl;
00071             else if(htype == ProbaType)   
00072                cout << ":\t probability distribution not available" << endl;
00073             else if(htype == RarityType)  
00074                cout << ":\t rarity distribution not available" << endl;
00075             else if(htype == CompareType) 
00076                cout << ":\t overtraining check not available" << endl;
00077             else cout << endl;
00078             continue;
00079          }
00080 
00081          cout << " containing " << hname << "_S/_B" << endl;
00082          // chop off useless stuff
00083          sig->SetTitle( Form("TMVA response for classifier: %s", methodTitle.Data()) );
00084          if      (htype == ProbaType) 
00085             sig->SetTitle( Form("TMVA probability for classifier: %s", methodTitle.Data()) );
00086          else if (htype == RarityType) 
00087             sig->SetTitle( Form("TMVA Rarity for classifier: %s", methodTitle.Data()) );
00088          else if (htype == CompareType) 
00089             sig->SetTitle( Form("TMVA overtraining check for classifier: %s", methodTitle.Data()) );
00090          
00091          // create new canvas
00092          TString ctitle = ((htype == MVAType) ? 
00093                            Form("TMVA response %s",methodTitle.Data()) : 
00094                            (htype == ProbaType) ? 
00095                            Form("TMVA probability %s",methodTitle.Data()) :
00096                            (htype == CompareType) ? 
00097                            Form("TMVA comparison %s",methodTitle.Data()) :
00098                            Form("TMVA Rarity %s",methodTitle.Data()));
00099          
00100          c = new TCanvas( Form("canvas%d", countCanvas+1), ctitle, 
00101                           countCanvas*50+200, countCanvas*20, width, (Int_t)width*0.78 ); 
00102     
00103          // set the histogram style
00104          TMVAGlob::SetSignalAndBackgroundStyle( sig, bgd );
00105          
00106          // normalise both signal and background
00107          TMVAGlob::NormalizeHists( sig, bgd );
00108          
00109          // frame limits (choose judicuous x range)
00110          Float_t nrms = 4;
00111          cout << "--- Mean and RMS (S): " << sig->GetMean() << ", " << sig->GetRMS() << endl;
00112          cout << "--- Mean and RMS (B): " << bgd->GetMean() << ", " << bgd->GetRMS() << endl;
00113          Float_t xmin = TMath::Max( TMath::Min(sig->GetMean() - nrms*sig->GetRMS(), 
00114                                                bgd->GetMean() - nrms*bgd->GetRMS() ),
00115                                     sig->GetXaxis()->GetXmin() );
00116          Float_t xmax = TMath::Min( TMath::Max(sig->GetMean() + nrms*sig->GetRMS(), 
00117                                                bgd->GetMean() + nrms*bgd->GetRMS() ),
00118                                     sig->GetXaxis()->GetXmax() );
00119          Float_t ymin = 0;
00120          Float_t maxMult = (htype == CompareType) ? 1.3 : 1.2;
00121          Float_t ymax = TMath::Max( sig->GetMaximum(), bgd->GetMaximum() )*maxMult;
00122    
00123          // build a frame
00124          Int_t nb = 500;
00125          TString hFrameName(TString("frame") + methodTitle);
00126          TObject *o = gROOT->FindObject(hFrameName);
00127          if(o) delete o;
00128          TH2F* frame = new TH2F( hFrameName, sig->GetTitle(), 
00129                                  nb, xmin, xmax, nb, ymin, ymax );
00130          frame->GetXaxis()->SetTitle( methodTitle + ((htype == MVAType || htype == CompareType) ? " response" : "") );
00131          if      (htype == ProbaType  ) frame->GetXaxis()->SetTitle( "Signal probability" );
00132          else if (htype == RarityType ) frame->GetXaxis()->SetTitle( "Signal rarity" );
00133          frame->GetYaxis()->SetTitle("(1/N) dN^{ }/^{ }dx");
00134          TMVAGlob::SetFrameStyle( frame );
00135    
00136          // eventually: draw the frame
00137          frame->Draw();  
00138     
00139          c->GetPad(0)->SetLeftMargin( 0.105 );
00140          frame->GetYaxis()->SetTitleOffset( 1.2 );
00141 
00142          // Draw legend               
00143          TLegend *legend= new TLegend( c->GetLeftMargin(), 1 - c->GetTopMargin() - 0.12, 
00144                                        c->GetLeftMargin() + (htype == CompareType ? 0.40 : 0.3), 1 - c->GetTopMargin() );
00145          legend->SetFillStyle( 1 );
00146          legend->AddEntry(sig,TString("Signal")     + ((htype == CompareType) ? " (test sample)" : ""), "F");
00147          legend->AddEntry(bgd,TString("Background") + ((htype == CompareType) ? " (test sample)" : ""), "F");
00148          legend->SetBorderSize(1);
00149          legend->SetMargin( (htype == CompareType ? 0.2 : 0.3) );
00150          legend->Draw("same");
00151 
00152          // overlay signal and background histograms
00153          sig->Draw("samehist");
00154          bgd->Draw("samehist");
00155    
00156          if (htype == CompareType) {
00157             // if overtraining check, load additional histograms
00158             TH1* sigOv = 0;
00159             TH1* bgdOv = 0;
00160 
00161             TString ovname = hname += "_Train";
00162             sigOv = dynamic_cast<TH1*>(titDir->Get( ovname + "_S" ));
00163             bgdOv = dynamic_cast<TH1*>(titDir->Get( ovname + "_B" ));
00164       
00165             if (sigOv == 0 || bgdOv == 0) {
00166                cout << "+++ Problem in \"mvas.C\": overtraining check histograms do not exist" << endl;
00167             }
00168             else {
00169                cout << "--- Found comparison histograms for overtraining check" << endl;
00170 
00171                TLegend *legend2= new TLegend( 1 - c->GetRightMargin() - 0.42, 1 - c->GetTopMargin() - 0.12,
00172                                               1 - c->GetRightMargin(), 1 - c->GetTopMargin() );
00173                legend2->SetFillStyle( 1 );
00174                legend2->SetBorderSize(1);
00175                legend2->AddEntry(sigOv,"Signal (training sample)","P");
00176                legend2->AddEntry(bgdOv,"Background (training sample)","P");
00177                legend2->SetMargin( 0.1 );
00178                legend2->Draw("same");
00179             }
00180             // normalise both signal and background
00181             TMVAGlob::NormalizeHists( sigOv, bgdOv );
00182 
00183             Int_t col = sig->GetLineColor();
00184             sigOv->SetMarkerColor( col );
00185             sigOv->SetMarkerSize( 0.7 );
00186             sigOv->SetMarkerStyle( 20 );
00187             sigOv->SetLineWidth( 1 );
00188             sigOv->SetLineColor( col );
00189             sigOv->Draw("e1same");
00190       
00191             col = bgd->GetLineColor();
00192             bgdOv->SetMarkerColor( col );
00193             bgdOv->SetMarkerSize( 0.7 );
00194             bgdOv->SetMarkerStyle( 20 );
00195             bgdOv->SetLineWidth( 1 );
00196             bgdOv->SetLineColor( col );
00197             bgdOv->Draw("e1same");
00198 
00199             ymax = TMath::Max( ymax, TMath::Max( sigOv->GetMaximum(), bgdOv->GetMaximum() )*maxMult );
00200             frame->GetYaxis()->SetLimits( 0, ymax );
00201       
00202             // for better visibility, plot thinner lines
00203             sig->SetLineWidth( 1 );
00204             bgd->SetLineWidth( 1 );
00205 
00206             // perform K-S test
00207             cout << "--- Perform Kolmogorov-Smirnov tests" << endl;
00208             Double_t kolS = sig->KolmogorovTest( sigOv );
00209             Double_t kolB = bgd->KolmogorovTest( bgdOv );
00210             cout << "--- Goodness of signal (background) consistency: " << kolS << " (" << kolB << ")" << endl;
00211 
00212             TString probatext = Form( "Kolmogorov-Smirnov test: signal (background) probability = %5.3g (%5.3g)", kolS, kolB );
00213             TText* tt = new TText( 0.12, 0.74, probatext );
00214             tt->SetNDC(); tt->SetTextSize( 0.032 ); tt->AppendPad(); 
00215          }
00216 
00217          // redraw axes
00218          frame->Draw("sameaxis");
00219 
00220          // text for overflows
00221          Int_t    nbin = sig->GetNbinsX();
00222          Double_t dxu  = sig->GetBinWidth(0);
00223          Double_t dxo  = sig->GetBinWidth(nbin+1);
00224          TString uoflow = Form( "U/O-flow (S,B): (%.1f, %.1f)%% / (%.1f, %.1f)%%", 
00225                                 sig->GetBinContent(0)*dxu*100, bgd->GetBinContent(0)*dxu*100,
00226                                 sig->GetBinContent(nbin+1)*dxo*100, bgd->GetBinContent(nbin+1)*dxo*100 );
00227          TText* t = new TText( 0.975, 0.115, uoflow );
00228          t->SetNDC();
00229          t->SetTextSize( 0.030 );
00230          t->SetTextAngle( 90 );
00231          t->AppendPad();    
00232    
00233          // update canvas
00234          c->Update();
00235 
00236          // save canvas to file
00237 
00238          TMVAGlob::plot_logo(1.058);
00239          if (Save_Images) {
00240             if      (htype == MVAType)     TMVAGlob::imgconv( c, Form("plots/mva_%s",     methodTitle.Data()) );
00241             else if (htype == ProbaType)   TMVAGlob::imgconv( c, Form("plots/proba_%s",   methodTitle.Data()) ); 
00242             else if (htype == CompareType) TMVAGlob::imgconv( c, Form("plots/overtrain_%s", methodTitle.Data()) ); 
00243             else                           TMVAGlob::imgconv( c, Form("plots/rarity_%s",  methodTitle.Data()) ); 
00244          }
00245          countCanvas++;
00246          
00247       }
00248       cout << "";
00249    }
00250 }
00251 

Generated on Tue Jul 5 15:25:45 2011 for ROOT_528-00b_version by  doxygen 1.5.1