00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024
00025
00026
00027
00028
00029
00030
00031
00032
00033
00034
00035
00036
00037
00038
00039
00040
00041
00042
00043
00044
00045
00046
00047
00048 #ifndef RooStats_RooStatsUtils
00049 #include "RooStats/RooStatsUtils.h"
00050 #endif
00051 #ifndef ROOT_Rtypes
00052 #include "Rtypes.h"
00053 #endif
00054 #ifndef ROO_REAL_VAR
00055 #include "RooRealVar.h"
00056 #endif
00057 #ifndef ROO_NLL_VAR
00058 #include "RooNLLVar.h"
00059 #endif
00060 #ifndef ROO_GLOBAL_FUNC
00061 #include "RooGlobalFunc.h"
00062 #endif
00063 #ifndef ROO_DATA_SET
00064 #include "RooDataSet.h"
00065 #endif
00066 #ifndef ROO_ARG_SET
00067 #include "RooArgSet.h"
00068 #endif
00069 #ifndef ROO_ARG_LIST
00070 #include "RooArgList.h"
00071 #endif
00072 #ifndef ROO_MSG_SERVICE
00073 #include "RooMsgService.h"
00074 #endif
00075 #ifndef ROO_RANDOM
00076 #include "RooRandom.h"
00077 #endif
00078 #ifndef ROOT_TH1
00079 #include "TH1.h"
00080 #endif
00081 #ifndef ROOT_TMath
00082 #include "TMath.h"
00083 #endif
00084 #ifndef ROOT_TFile
00085 #include "TFile.h"
00086 #endif
00087 #ifndef ROOSTATS_MetropolisHastings
00088 #include "RooStats/MetropolisHastings.h"
00089 #endif
00090 #ifndef ROOSTATS_MarkovChain
00091 #include "RooStats/MarkovChain.h"
00092 #endif
00093 #ifndef RooStats_MCMCInterval
00094 #include "RooStats/MCMCInterval.h"
00095 #endif
00096
// ROOT class-implementation macro for RooStats::MetropolisHastings
// (presumably provided by Rtypes.h, included above). The class name is spelled
// fully qualified because this line precedes the using-directives below.
00097 ClassImp(RooStats::MetropolisHastings);
00098
// Bring RooFit (e.g. RooFit::ERROR, MsgLevel) and RooStats (MarkovChain,
// RandomizeCollection, SetParameters, ...) names into scope for this file.
00099 using namespace RooFit;
00100 using namespace RooStats;
00101
00102 MetropolisHastings::MetropolisHastings()
00103 {
00104
00105 fFunction = NULL;
00106 fParameters = NULL;
00107 fPropFunc = NULL;
00108 fNumIters = 0;
00109 fNumBurnInSteps = 0;
00110 fSign = kSignUnset;
00111 fType = kTypeUnset;
00112 }
00113
00114 MetropolisHastings::MetropolisHastings(RooAbsReal& function, RooArgSet& paramsOfInterest,
00115 ProposalFunction& proposalFunction, Int_t numIters)
00116 {
00117 fFunction = &function;
00118 SetParameters(paramsOfInterest);
00119 SetProposalFunction(proposalFunction);
00120 fNumIters = numIters;
00121 fNumBurnInSteps = 0;
00122 fSign = kSignUnset;
00123 fType = kTypeUnset;
00124 }
00125
00126 MarkovChain* MetropolisHastings::ConstructChain()
00127 {
00128 if (!fParameters || !fPropFunc || !fFunction) {
00129 coutE(Eval) << "Critical members unintialized: parameters, proposal " <<
00130 " function, or (log) likelihood function" << endl;
00131 return NULL;
00132 }
00133 if (fSign == kSignUnset || fType == kTypeUnset) {
00134 coutE(Eval) << "Please set type and sign of your function using "
00135 << "MetropolisHastings::SetType() and MetropolisHastings::SetSign()" <<
00136 endl;
00137 return NULL;
00138 }
00139
00140 RooArgSet x;
00141 RooArgSet xPrime;
00142 x.addClone(*fParameters);
00143 RandomizeCollection(x);
00144 xPrime.addClone(*fParameters);
00145 RandomizeCollection(xPrime);
00146
00147 MarkovChain* chain = new MarkovChain();
00148 chain->SetParameters(*fParameters);
00149
00150 Int_t weight = 0;
00151 Double_t xL = 0.0, xPrimeL = 0.0, a = 0.0;
00152
00153 RooFit::MsgLevel oldMsgLevel = RooMsgService::instance().globalKillBelow();
00154 RooMsgService::instance().setGlobalKillBelow(RooFit::ERROR);
00155
00156
00157
00158
00159 if (fType == kLog)
00160 fFunction->setEvalErrorLoggingMode(RooAbsReal::CountErrors);
00161
00162 bool hadEvalError = true;
00163
00164 Int_t i = 0;
00165
00166
00167
00168
00169
00170
00171
00172 while (i < 1000 && hadEvalError) {
00173 RandomizeCollection(x);
00174 RooStats::SetParameters(&x, fParameters);
00175 xL = fFunction->getVal();
00176
00177 if (fType == kLog) {
00178 if (fFunction->numEvalErrors() > 0) {
00179 fFunction->clearEvalErrorLog();
00180 hadEvalError = true;
00181 } else
00182 hadEvalError = false;
00183 } else if (fType == kRegular) {
00184 if (xL == 0.0)
00185 hadEvalError = true;
00186 else
00187 hadEvalError = false;
00188 } else
00189
00190 hadEvalError = false;
00191 }
00192
00193 if(hadEvalError) {
00194 coutE(Eval) << "Problem finding a good starting point in " <<
00195 "MetropolisHastings::ConstructChain() " << endl;
00196 }
00197
00198
00199 for (i = 0; i < fNumIters; i++) {
00200
00201 hadEvalError = false;
00202
00203 if (i % (fNumIters / 100) == 0) {
00204
00205 fprintf(stdout, ".");
00206 fflush(NULL);
00207 }
00208
00209 fPropFunc->Propose(xPrime, x);
00210
00211 RooStats::SetParameters(&xPrime, fParameters);
00212 xPrimeL = fFunction->getVal();
00213
00214
00215 if (fFunction->numEvalErrors() > 0 && fType == kLog) {
00216 xPrimeL = RooNumber::infinity();
00217 fFunction->clearEvalErrorLog();
00218 hadEvalError = true;
00219 }
00220
00221
00222
00223
00224
00225
00226 if (fType == kLog) {
00227 if (fSign == kPositive)
00228 a = xL - xPrimeL;
00229 else
00230 a = xPrimeL - xL;
00231 }
00232 else
00233 a = xPrimeL / xL;
00234
00235
00236 if (!hadEvalError && !fPropFunc->IsSymmetric(xPrime, x)) {
00237 Double_t xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
00238 Double_t xPD = fPropFunc->GetProposalDensity(x, xPrime);
00239 if (fType == kRegular)
00240 a *= xPD / xPrimePD;
00241 else
00242 a += TMath::Log(xPrimePD) - TMath::Log(xPD);
00243 }
00244
00245 if (!hadEvalError && ShouldTakeStep(a)) {
00246
00247
00248
00249 if (weight != 0.0)
00250 chain->Add(x, CalcNLL(xL), (Double_t)weight);
00251
00252
00253 weight = 1;
00254 RooStats::SetParameters(&xPrime, &x);
00255 xL = xPrimeL;
00256 } else {
00257
00258 weight++;
00259 }
00260 }
00261
00262
00263 if (weight != 0.0)
00264 chain->Add(x, CalcNLL(xL), (Double_t)weight);
00265 printf("\n");
00266
00267 RooMsgService::instance().setGlobalKillBelow(oldMsgLevel);
00268
00269 Int_t numAccepted = chain->Size();
00270 coutI(Eval) << "Proposal acceptance rate: " <<
00271 numAccepted/(Float_t)fNumIters * 100 << "%" << endl;
00272 coutI(Eval) << "Number of steps in chain: " << numAccepted << endl;
00273
00274
00275
00276
00277
00278 return chain;
00279 }
00280
00281 Bool_t MetropolisHastings::ShouldTakeStep(Double_t a)
00282 {
00283 if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
00284
00285
00286 return kTRUE;
00287 }
00288 else {
00289
00290
00291
00292 Double_t rand = RooRandom::uniform();
00293 if (fType == kLog) {
00294 rand = TMath::Log(rand);
00295
00296
00297 if (-1.0 * rand >= a)
00298
00299
00300 return kTRUE;
00301 } else {
00302
00303
00304
00305 if (rand < a)
00306
00307
00308 return kTRUE;
00309 }
00310 return kFALSE;
00311 }
00312 }
00313
00314 Double_t MetropolisHastings::CalcNLL(Double_t xL)
00315 {
00316 if (fType == kLog) {
00317 if (fSign == kNegative)
00318 return xL;
00319 else
00320 return -xL;
00321 } else {
00322 if (fSign == kPositive)
00323 return -1.0 * TMath::Log(xL);
00324 else
00325 return -1.0 * TMath::Log(-xL);
00326 }
00327 }