mvasMulticlass.C

Go to the documentation of this file.
00001 #include "TLegend.h"
00002 #include "TText.h"
00003 #include "TH2.h"
00004 
00005 #include "tmvaglob.C"
00006 
// This macro draws, for every multiclass MVA method found in a TMVA output
// file (by default TMVAMulticlass.root), the classifier response
// distributions of all classes overlaid, one canvas per (method, class) pair.

// Selects what mvasMulticlass() draws on each canvas.
enum HistType {
   MVAType     = 0,   // test-sample response distributions only
   CompareType = 1    // test-sample responses overlaid with the training-sample
                      // responses (overtraining check, incl. Kolmogorov-Smirnov test)
};
00012 
00013 // input: - Input file (result from TMVA)
00014 //        - use of TMVA plotting TStyle
00015 void mvasMulticlass( TString fin = "TMVAMulticlass.root", HistType htype = MVAType, Bool_t useTMVAStyle = kTRUE )
00016 {
00017    // set style and remove existing canvas'
00018    TMVAGlob::Initialize( useTMVAStyle );
00019 
00020    // switches
00021    const Bool_t Save_Images = kTRUE;
00022 
00023    // checks if file with name "fin" is already open, and if not opens one
00024    TFile* file = TMVAGlob::OpenFile( fin );  
00025 
00026    TDirectory* tempdir = (TDirectory*)file->Get("InputVariables_Id" );
00027    std::vector<TString> classnames(TMVAGlob::GetClassNames(tempdir));
00028 
00029    // define Canvas layout here!
00030    Int_t xPad = 1; // no of plots in x
00031    Int_t yPad = 1; // no of plots in y
00032    Int_t noPad = xPad * yPad ; 
00033    const Int_t width = 600;   // size of canvas
00034 
00035    // this defines how many canvases we need
00036    TCanvas *c = 0;
00037 
00038    // counter variables
00039    Int_t countCanvas = 0;
00040 
00041    // search for the right histograms in full list of keys
00042    TIter next(file->GetListOfKeys());
00043    TKey *key(0);   
00044    while ((key = (TKey*)next())) {
00045 
00046       if (!TString(key->GetName()).BeginsWith("Method_")) continue;
00047       if (!gROOT->GetClass(key->GetClassName())->InheritsFrom("TDirectory")) continue;
00048 
00049       TString methodName;
00050       TMVAGlob::GetMethodName(methodName,key);
00051 
00052       TDirectory* mDir = (TDirectory*)key->ReadObj();
00053 
00054       TIter keyIt(mDir->GetListOfKeys());
00055       TKey *titkey;
00056       while ((titkey = (TKey*)keyIt())) {
00057 
00058          if (!gROOT->GetClass(titkey->GetClassName())->InheritsFrom("TDirectory")) continue;
00059 
00060          TDirectory *titDir = (TDirectory *)titkey->ReadObj();
00061          TString methodTitle;
00062          TMVAGlob::GetMethodTitle(methodTitle,titDir);
00063 
00064          cout << "--- Found directory for method: " << methodName << "::" << methodTitle << endl;
00065          TString hname = "MVA_" + methodTitle;
00066           for(Int_t icls = 0; icls < classnames.size(); ++icls){
00067          TObjArray hists;
00068 
00069          std::vector<TString>::iterator classiter = classnames.begin();
00070          for(; classiter!=classnames.end(); ++classiter){
00071             TString name(hname+"_Test_"+ classnames.at(icls)
00072                          + "_prob_for_" + *classiter);
00073             TH1 *hist = (TH1*)titDir->Get(name);
00074             if (hist==0){
00075                cout << ":\t mva distribution not available (this is normal for Cut classifier)" << endl;
00076                continue;
00077             }
00078             hists.Add(hist);
00079          }
00080          
00081        
00082          // chop off useless stuff
00083          ((TH1*)hists.First())->SetTitle( Form("TMVA response for classifier: %s", methodTitle.Data() ));
00084            
00085          // create new canvas
00086          //cout << "Create canvas..." << endl;
00087          TString ctitle = ((htype == MVAType) ? 
00088                            Form("TMVA response for class %s %s", classnames.at(icls).Data(),methodTitle.Data()) :                
00089                            Form("TMVA comparison for class %s %s", classnames.at(icls).Data(),methodTitle.Data())) ;
00090          
00091          c = new TCanvas( Form("canvas%d", countCanvas+1), ctitle, 
00092                           countCanvas*50+200, countCanvas*20, width, (Int_t)width*0.78 ); 
00093     
00094          // set the histogram style
00095          //cout << "Set histogram style..." << endl;
00096          TMVAGlob::SetMultiClassStyle( &hists );
00097          
00098          // normalise all histograms and find maximum
00099          Float_t histmax = -1;
00100          for(Int_t i=0; i<hists.GetEntriesFast(); ++i){
00101             TMVAGlob::NormalizeHist((TH1*)hists[i] );
00102             if(((TH1*)hists[i])->GetMaximum() > histmax)
00103                histmax = ((TH1*)hists[i])->GetMaximum();
00104          }
00105          
00106          // frame limits (between 0 and 1 per definition)
00107          Float_t xmin = 0;
00108          Float_t xmax = 1;
00109          Float_t ymin = 0;
00110          Float_t maxMult = (htype == CompareType) ? 1.3 : 1.2;
00111          Float_t ymax = histmax*maxMult; 
00112          // build a frame
00113          Int_t nb = 500;
00114          TString hFrameName(TString("frame") + methodTitle);
00115          TObject *o = gROOT->FindObject(hFrameName);
00116          if(o) delete o;
00117          TH2F* frame = new TH2F( hFrameName, ((TH1*)hists.First())->GetTitle(), 
00118                                  nb, xmin, xmax, nb, ymin, ymax );
00119          frame->GetXaxis()->SetTitle( methodTitle + " response for "+classnames.at(icls));
00120           frame->GetYaxis()->SetTitle("(1/N) dN^{ }/^{ }dx");
00121          TMVAGlob::SetFrameStyle( frame );
00122    
00123          // eventually: draw the frame
00124          frame->Draw();  
00125     
00126          c->GetPad(0)->SetLeftMargin( 0.105 );
00127          frame->GetYaxis()->SetTitleOffset( 1.2 );
00128 
00129          // Draw legend               
00130          TLegend *legend= new TLegend( c->GetLeftMargin(), 1 - c->GetTopMargin() - 0.12, 
00131                                        c->GetLeftMargin() + (htype == CompareType ? 0.40 : 0.3), 1 - c->GetTopMargin() );
00132          legend->SetFillStyle( 1 );
00133          classiter = classnames.begin();
00134          
00135          for(Int_t i=0; i<hists.GetEntriesFast(); ++i, ++classiter){
00136             legend->AddEntry(((TH1*)hists[i]),*classiter,"F");
00137          }
00138          
00139          legend->SetBorderSize(1);
00140          legend->SetMargin( 0.3 );
00141          legend->Draw("same");
00142          
00143          
00144          for(Int_t i=0; i<hists.GetEntriesFast(); ++i){
00145 
00146             ((TH1*)hists[i])->Draw("histsame");
00147             TString ytit = TString("(1/N) ") + ((TH1*)hists[i])->GetYaxis()->GetTitle();
00148             ((TH1*)hists[i])->GetYaxis()->SetTitle( ytit ); // histograms are normalised
00149       
00150          }
00151        
00152          
00153          if (htype == CompareType) {
00154             
00155             TObjArray othists; 
00156             // if overtraining check, load additional histograms
00157             classiter = classnames.begin();
00158             for(; classiter!=classnames.end(); ++classiter){
00159                TString name(hname+"_Train_"+ classnames.at(icls)
00160                             + "_prob_for_" + *classiter);
00161                TH1 *hist = (TH1*)titDir->Get(name);
00162                if (hist==0){
00163                   cout << ":\t comparison histogram for overtraining check not available!" << endl;
00164                   continue;
00165                }
00166                othists.Add(hist);
00167             }
00168             
00169             TLegend *legend2= new TLegend( 1 - c->GetRightMargin() - 0.42, 1 - c->GetTopMargin() - 0.12,
00170                                            1 - c->GetRightMargin(), 1 - c->GetTopMargin() );
00171             legend2->SetFillStyle( 1 );
00172             legend2->SetBorderSize(1);
00173             
00174             classiter = classnames.begin();
00175             for(Int_t i=0; i<othists.GetEntriesFast(); ++i, ++classiter){
00176                legend2->AddEntry(((TH1*)othists[i]),*classiter+" (training sample)","P");
00177             }
00178             legend2->SetMargin( 0.1 );
00179             legend2->Draw("same");
00180             
00181             // normalise all histograms and get maximum
00182             for(Int_t i=0; i<othists.GetEntriesFast(); ++i){
00183                TMVAGlob::NormalizeHist((TH1*)othists[i] );
00184                if(((TH1*)othists[i])->GetMaximum() > histmax)
00185                   histmax = ((TH1*)othists[i])->GetMaximum();
00186             }
00187 
00188             TMVAGlob::SetMultiClassStyle( &othists );
00189             for(Int_t i=0; i<othists.GetEntriesFast(); ++i){
00190                Int_t col = ((TH1*)hists[i])->GetLineColor();
00191                ((TH1*)othists[i])->SetMarkerSize( 0.7 );
00192                ((TH1*)othists[i])->SetMarkerStyle( 20 );
00193                ((TH1*)othists[i])->SetMarkerColor( col );
00194                ((TH1*)othists[i])->SetLineWidth( 1 );
00195                ((TH1*)othists[i])->Draw("e1same");
00196             }
00197             
00198             ymax = histmax*maxMult;
00199             frame->GetYaxis()->SetLimits( 0, ymax );
00200       
00201             // for better visibility, plot thinner lines
00202             TMVAGlob::SetMultiClassStyle( &othists );
00203             for(Int_t i=0; i<hists.GetEntriesFast(); ++i){
00204                 ((TH1*)hists[i])->SetLineWidth( 1 );
00205             }
00206             
00207             
00208             // perform K-S test
00209             
00210             cout << "--- Perform Kolmogorov-Smirnov tests" << endl;
00211             cout << "--- Goodness of consistency for class " << classnames.at(icls)<< endl;
00212             //TString probatext("Kolmogorov-Smirnov test: ");
00213             for(Int_t j=0; j<othists.GetEntriesFast(); ++j){
00214                Float_t kol = ((TH1*)hists[j])->KolmogorovTest(((TH1*)othists[j]));
00215                cout <<  classnames.at(j) << ": " << kol  << endl;
00216                //probatext.Append(classnames.at(j)+Form(" %.3f ",kol));
00217             }
00218             
00219            
00220            
00221             //TText* tt = new TText( 0.12, 0.74, probatext );
00222             //tt->SetNDC(); tt->SetTextSize( 0.032 ); tt->AppendPad();
00223             
00224          }
00225          
00226          
00227          // redraw axes
00228          frame->Draw("sameaxis");
00229 
00230          // text for overflows
00231          //Int_t    nbin = sig->GetNbinsX();
00232          //Double_t dxu  = sig->GetBinWidth(0);
00233          //Double_t dxo  = sig->GetBinWidth(nbin+1);
00234          //TString uoflow = Form( "U/O-flow (S,B): (%.1f, %.1f)%% / (%.1f, %.1f)%%", 
00235          //                       sig->GetBinContent(0)*dxu*100, bgd->GetBinContent(0)*dxu*100,
00236          //                      sig->GetBinContent(nbin+1)*dxo*100, bgd->GetBinContent(nbin+1)*dxo*100 );
00237       //TText* t = new TText( 0.975, 0.115, uoflow );
00238          //t->SetNDC();
00239          //t->SetTextSize( 0.030 );
00240          //t->SetTextAngle( 90 );
00241          //t->AppendPad();    
00242    
00243          // update canvas
00244          c->Update();
00245 
00246          // save canvas to file
00247 
00248          TMVAGlob::plot_logo(1.058);
00249          if (Save_Images) {
00250             if      (htype == MVAType)     TMVAGlob::imgconv( c, Form("plots/mva_%s_%s",classnames.at(icls).Data(), methodTitle.Data()) );
00251             else if      (htype == CompareType)     TMVAGlob::imgconv( c, Form("plots/overtrain_%s_%s",classnames.at(icls).Data(), methodTitle.Data()) );
00252 
00253          }
00254          countCanvas++;
00255          }
00256       }
00257       cout << "";
00258    }
00259 }
00260 

Generated on Tue Jul 5 15:25:45 2011 for ROOT_528-00b_version by  doxygen 1.5.1