00001
00002
00003 #include "RooGlobalFunc.h"
00004 #include <stdlib.h>
00005 #include "TMatrixDSym.h"
00006 #include "RooMultiVarGaussian.h"
00007 #include "RooArgList.h"
00008 #include "RooRealVar.h"
00009 #include "TH2F.h"
00010 #include "TCanvas.h"
00011 #include "RooAbsReal.h"
00012 #include "RooFitResult.h"
00013 #include "TStopwatch.h"
00014 #include "RooStats/MCMCCalculator.h"
00015 #include "RooStats/MetropolisHastings.h"
00016 #include "RooStats/MarkovChain.h"
00017 #include "RooStats/ConfInterval.h"
00018 #include "RooStats/MCMCInterval.h"
00019 #include "RooStats/MCMCIntervalPlot.h"
00020 #include "RooStats/ModelConfig.h"
00021 #include "RooStats/ProposalHelper.h"
00022 #include "RooStats/ProposalFunction.h"
00023 #include "RooStats/PdfProposal.h"
00024 #include "RooStats/ProfileLikelihoodCalculator.h"
00025 #include "RooStats/LikelihoodIntervalPlot.h"
00026 #include "RooStats/LikelihoodInterval.h"
00027
00028 using namespace std;
00029 using namespace RooFit;
00030 using namespace RooStats;
00031
00032
00033 void MultivariateGaussianTest(Int_t dim = 4, Int_t nPOI = 2)
00034 {
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048
00049
00050
00051
00052
00053
00054
00055 TStopwatch t;
00056 t.Start();
00057
00058 RooArgList xVec;
00059 RooArgList muVec;
00060 RooArgSet poi;
00061
00062
00063 Int_t i,j;
00064 RooRealVar* x;
00065 RooRealVar* mu_x;
00066 for (i = 0; i < dim; i++) {
00067 char* name = Form("x%d", i);
00068 x = new RooRealVar(name, name, 0, -3,3);
00069 xVec.add(*x);
00070
00071 char* mu_name = Form("mu_x%d",i);
00072 mu_x = new RooRealVar(mu_name, mu_name, 0, -2,2);
00073 muVec.add(*mu_x);
00074 }
00075
00076
00077 for (i = 0; i < nPOI; i++) {
00078 poi.add(*muVec.at(i));
00079 }
00080
00081
00082 TMatrixDSym cov(dim);
00083 for (i = 0; i < dim; i++) {
00084 for (j = 0; j < dim; j++) {
00085 if (i == j) cov(i,j) = 3.;
00086 else cov(i,j) = 1.0;
00087 }
00088 }
00089
00090
00091 RooMultiVarGaussian mvg("mvg", "mvg", xVec, muVec, cov);
00092
00093
00094
00095 RooDataSet* data = mvg.generate(xVec, 100);
00096
00097
00098
00099 RooWorkspace* w = new RooWorkspace("MVG");
00100 ModelConfig modelConfig(w);
00101 modelConfig.SetPdf(mvg);
00102 modelConfig.SetParametersOfInterest(poi);
00103
00104
00105
00106
00107
00108
00109
00110 RooFitResult* fit = mvg.fitTo(*data, Save(true));
00111 ProposalHelper ph;
00112 ph.SetVariables((RooArgSet&)fit->floatParsFinal());
00113 ph.SetCovMatrix(fit->covarianceMatrix());
00114 ph.SetUpdateProposalParameters(true);
00115 ph.SetCacheSize(100);
00116 ProposalFunction* pdfProp = ph.GetProposalFunction();
00117
00118
00119 MCMCCalculator mc(*data, modelConfig);
00120 mc.SetConfidenceLevel(0.95);
00121 mc.SetNumBurnInSteps(100);
00122 mc.SetNumIters(10000);
00123 mc.SetNumBins(50);
00124 mc.SetProposalFunction(*pdfProp);
00125
00126 MCMCInterval* mcInt = mc.GetInterval();
00127 RooArgList* poiList = mcInt->GetAxes();
00128
00129
00130 ProfileLikelihoodCalculator plc(*data, modelConfig);
00131 plc.SetConfidenceLevel(0.95);
00132 LikelihoodInterval* plInt = (LikelihoodInterval*)plc.GetInterval();
00133
00134
00135
00136
00137 MCMCIntervalPlot mcPlot(*mcInt);
00138
00139 TCanvas* c1 = new TCanvas();
00140 mcPlot.SetLineColor(kGreen);
00141 mcPlot.SetLineWidth(2);
00142 mcPlot.Draw();
00143
00144 LikelihoodIntervalPlot plPlot(plInt);
00145 plPlot.Draw("same");
00146
00147 if (poiList->getSize() == 1) {
00148 RooRealVar* p = (RooRealVar*)poiList->at(0);
00149 Double_t ll = mcInt->LowerLimit(*p);
00150 Double_t ul = mcInt->UpperLimit(*p);
00151 cout << "MCMC interval: [" << ll << ", " << ul << "]" << endl;
00152 }
00153
00154 if (poiList->getSize() == 2) {
00155 RooRealVar* p0 = (RooRealVar*)poiList->at(0);
00156 RooRealVar* p1 = (RooRealVar*)poiList->at(1);
00157 TCanvas* scatter = new TCanvas();
00158 Double_t ll = mcInt->LowerLimit(*p0);
00159 Double_t ul = mcInt->UpperLimit(*p0);
00160 cout << "MCMC interval on p0: [" << ll << ", " << ul << "]" << endl;
00161 ll = mcInt->LowerLimit(*p0);
00162 ul = mcInt->UpperLimit(*p0);
00163 cout << "MCMC interval on p1: [" << ll << ", " << ul << "]" << endl;
00164
00165
00166
00167
00168 mcPlot.DrawChainScatter(*p0, *p1);
00169 scatter->Update();
00170 }
00171
00172 t.Print();
00173
00174 }
00175