MetropolisHastings.cxx

Go to the documentation of this file.
00001 // @(#)root/roostats:$Id: MetropolisHastings.cxx 34109 2010-06-24 15:00:16Z moneta $
00002 // Authors: Kevin Belasco        17/06/2009
00003 // Authors: Kyle Cranmer         17/06/2009
00004 /*************************************************************************
00005  * Copyright (C) 1995-2008, Rene Brun and Fons Rademakers.               *
00006  * All rights reserved.                                                  *
00007  *                                                                       *
00008  * For the licensing terms see $ROOTSYS/LICENSE.                         *
00009  * For the list of contributors see $ROOTSYS/README/CREDITS.             *
00010  *************************************************************************/
00011 
00012 //_________________________________________________
00013 /*
00014 BEGIN_HTML
00015 <p>
00016 This class uses the Metropolis-Hastings algorithm to construct a Markov Chain
00017 of data points using Monte Carlo. In the main algorithm, new points in the
00018 parameter space are proposed and then visited based on their relative
00019 likelihoods.  This class can use any implementation of the ProposalFunction,
00020 including non-symmetric proposal functions, to propose parameter points and
00021 still maintain detailed balance when constructing the chain.
00022 </p>
00023 
00024 <p>
00025 The "Likelihood" function that is sampled when deciding what steps to take in
00026 the chain has been given a very generic implementation.  The user can create
00027 any RooAbsReal based on the parameters and pass it to a MetropolisHastings
00028 object with the method SetFunction(RooAbsReal&).  Be sure to tell
00029 MetropolisHastings the sign (positive or negative) of your RooAbsReal and whether it is on a regular or log scale,
00030 so that it knows what logic to use when sampling your RooAbsReal.  For example,
00031 a common use is to sample from a -log(Likelihood) distribution (NLL), for which
00032 the appropriate configuration calls are SetType(MetropolisHastings::kLog);
00033 SetSign(MetropolisHastings::kNegative);
00034 If you're using a traditional likelihood function:
00035 SetType(MetropolisHastings::kRegular);  SetSign(MetropolisHastings::kPositive);
00036 You must set these type and sign flags or MetropolisHastings will not construct
00037 a MarkovChain.
00038 </p>
00039 
00040 <p>
00041 Also note that in ConstructChain(), the values of the variables are randomized
00042 uniformly over their intervals before construction of the MarkovChain begins.
00043 </p>
00044 END_HTML
00045 */
00046 //_________________________________________________
00047 
00048 #ifndef RooStats_RooStatsUtils
00049 #include "RooStats/RooStatsUtils.h"
00050 #endif
00051 #ifndef ROOT_Rtypes
00052 #include "Rtypes.h"
00053 #endif
00054 #ifndef ROO_REAL_VAR
00055 #include "RooRealVar.h"
00056 #endif
00057 #ifndef ROO_NLL_VAR
00058 #include "RooNLLVar.h"
00059 #endif
00060 #ifndef ROO_GLOBAL_FUNC
00061 #include "RooGlobalFunc.h"
00062 #endif
00063 #ifndef ROO_DATA_SET
00064 #include "RooDataSet.h"
00065 #endif
00066 #ifndef ROO_ARG_SET
00067 #include "RooArgSet.h"
00068 #endif
00069 #ifndef ROO_ARG_LIST
00070 #include "RooArgList.h"
00071 #endif
00072 #ifndef ROO_MSG_SERVICE
00073 #include "RooMsgService.h"
00074 #endif
00075 #ifndef ROO_RANDOM
00076 #include "RooRandom.h"
00077 #endif
00078 #ifndef ROOT_TH1
00079 #include "TH1.h"
00080 #endif
00081 #ifndef ROOT_TMath
00082 #include "TMath.h"
00083 #endif
00084 #ifndef ROOT_TFile
00085 #include "TFile.h"
00086 #endif
00087 #ifndef ROOSTATS_MetropolisHastings
00088 #include "RooStats/MetropolisHastings.h"
00089 #endif
00090 #ifndef ROOSTATS_MarkovChain
00091 #include "RooStats/MarkovChain.h"
00092 #endif
00093 #ifndef RooStats_MCMCInterval
00094 #include "RooStats/MCMCInterval.h"
00095 #endif
00096 
// ROOT class-implementation macro for RooStats::MetropolisHastings.
ClassImp(RooStats::MetropolisHastings);

// Bring the RooFit and RooStats namespaces into scope for the definitions
// below (e.g. the unqualified MetropolisHastings and MarkovChain names).
using namespace RooFit;
using namespace RooStats;
00101 
00102 MetropolisHastings::MetropolisHastings()
00103 {
00104    // default constructor
00105    fFunction = NULL;
00106    fParameters = NULL;
00107    fPropFunc = NULL;
00108    fNumIters = 0;
00109    fNumBurnInSteps = 0;
00110    fSign = kSignUnset;
00111    fType = kTypeUnset;
00112 }
00113 
00114 MetropolisHastings::MetropolisHastings(RooAbsReal& function, RooArgSet& paramsOfInterest,
00115       ProposalFunction& proposalFunction, Int_t numIters)
00116 {
00117    fFunction = &function;
00118    SetParameters(paramsOfInterest);
00119    SetProposalFunction(proposalFunction);
00120    fNumIters = numIters;
00121    fNumBurnInSteps = 0;
00122    fSign = kSignUnset;
00123    fType = kTypeUnset;
00124 }
00125 
00126 MarkovChain* MetropolisHastings::ConstructChain()
00127 {
00128    if (!fParameters || !fPropFunc || !fFunction) {
00129       coutE(Eval) << "Critical members unintialized: parameters, proposal " <<
00130                      " function, or (log) likelihood function" << endl;
00131          return NULL;
00132    }
00133    if (fSign == kSignUnset || fType == kTypeUnset) {
00134       coutE(Eval) << "Please set type and sign of your function using "
00135          << "MetropolisHastings::SetType() and MetropolisHastings::SetSign()" <<
00136          endl;
00137       return NULL;
00138    }
00139 
00140    RooArgSet x;
00141    RooArgSet xPrime;
00142    x.addClone(*fParameters);
00143    RandomizeCollection(x);
00144    xPrime.addClone(*fParameters);
00145    RandomizeCollection(xPrime);
00146 
00147    MarkovChain* chain = new MarkovChain();
00148    chain->SetParameters(*fParameters);
00149 
00150    Int_t weight = 0;
00151    Double_t xL = 0.0, xPrimeL = 0.0, a = 0.0;
00152 
00153    RooFit::MsgLevel oldMsgLevel = RooMsgService::instance().globalKillBelow();
00154    RooMsgService::instance().setGlobalKillBelow(RooFit::ERROR);
00155 
00156 
00157    // We will need to check if log-likelihood evaluation left an error status.
00158    // Now using faster eval error logging with CountErrors.
00159    if (fType == kLog)
00160      fFunction->setEvalErrorLoggingMode(RooAbsReal::CountErrors);
00161 
00162    bool hadEvalError = true;
00163 
00164    Int_t i = 0;
00165    // get a good starting point for x
00166    // for fType == kLog, this means that fFunction->getVal() did not cause
00167    // an eval error
00168    // for fType == kRegular this means fFunction->getVal() != 0
00169    //
00170    // kbelasco: i < 1000 is sort of arbitary, but way higher than the number of
00171    // steps we should have to take for any reasonable (log) likelihood function
00172    while (i < 1000 && hadEvalError) {
00173       RandomizeCollection(x);
00174       RooStats::SetParameters(&x, fParameters);
00175       xL = fFunction->getVal();
00176 
00177       if (fType == kLog) {
00178          if (fFunction->numEvalErrors() > 0) {
00179             fFunction->clearEvalErrorLog();
00180             hadEvalError = true;
00181          } else
00182             hadEvalError = false;
00183       } else if (fType == kRegular) {
00184          if (xL == 0.0)
00185             hadEvalError = true;
00186          else
00187             hadEvalError = false;
00188       } else
00189          // for now the only 2 types are kLog and kRegular (won't get here)
00190          hadEvalError = false;
00191    }
00192 
00193    if(hadEvalError) {
00194       coutE(Eval) << "Problem finding a good starting point in " <<
00195                      "MetropolisHastings::ConstructChain() " << endl;
00196    }
00197 
00198    // do main loop
00199    for (i = 0; i < fNumIters; i++) {
00200       // reset error handling flag
00201       hadEvalError = false;
00202 
00203       if (i % (fNumIters / 100) == 0) {
00204          // print a dot every 1% of the chain construction
00205          fprintf(stdout, ".");
00206          fflush(NULL);
00207       }
00208 
00209       fPropFunc->Propose(xPrime, x);
00210 
00211       RooStats::SetParameters(&xPrime, fParameters);
00212       xPrimeL = fFunction->getVal();
00213 
00214       // check if log-likelihood for xprime had an error status
00215       if (fFunction->numEvalErrors() > 0 && fType == kLog) {
00216          xPrimeL = RooNumber::infinity();
00217          fFunction->clearEvalErrorLog();
00218          hadEvalError = true;
00219       }
00220 
00221       // why evaluate the last point again, can't we cache it?
00222       // kbelasco: commenting out lines below to add/test caching support
00223       //RooStats::SetParameters(&x, fParameters);
00224       //xL = fFunction->getVal();
00225 
00226       if (fType == kLog) {
00227          if (fSign == kPositive)
00228             a = xL - xPrimeL;
00229          else
00230             a = xPrimeL - xL;
00231       }
00232       else
00233          a = xPrimeL / xL;
00234       //a = xL / xPrimeL;
00235 
00236       if (!hadEvalError && !fPropFunc->IsSymmetric(xPrime, x)) {
00237          Double_t xPrimePD = fPropFunc->GetProposalDensity(xPrime, x);
00238          Double_t xPD      = fPropFunc->GetProposalDensity(x, xPrime);
00239          if (fType == kRegular)
00240             a *= xPD / xPrimePD;
00241          else
00242             a += TMath::Log(xPrimePD) - TMath::Log(xPD);
00243       }
00244 
00245       if (!hadEvalError && ShouldTakeStep(a)) {
00246          // go to the proposed point xPrime
00247 
00248          // add the current point with the current weight
00249          if (weight != 0.0)
00250             chain->Add(x, CalcNLL(xL), (Double_t)weight);
00251 
00252          // reset the weight and go to xPrime
00253          weight = 1;
00254          RooStats::SetParameters(&xPrime, &x);
00255          xL = xPrimeL;
00256       } else {
00257          // stay at the current point
00258          weight++;
00259       }
00260    }
00261 
00262    // make sure to add the last point
00263    if (weight != 0.0)
00264       chain->Add(x, CalcNLL(xL), (Double_t)weight);
00265    printf("\n");
00266 
00267    RooMsgService::instance().setGlobalKillBelow(oldMsgLevel);
00268 
00269    Int_t numAccepted = chain->Size();
00270    coutI(Eval) << "Proposal acceptance rate: " <<
00271                    numAccepted/(Float_t)fNumIters * 100 << "%" << endl;
00272    coutI(Eval) << "Number of steps in chain: " << numAccepted << endl;
00273 
00274    //TFile chainDataFile("chainData.root", "recreate");
00275    //chain->GetDataSet()->Write();
00276    //chainDataFile.Close();
00277 
00278    return chain;
00279 }
00280 
00281 Bool_t MetropolisHastings::ShouldTakeStep(Double_t a)
00282 {
00283    if ((fType == kLog && a <= 0.0) || (fType == kRegular && a >= 1.0)) {
00284       // The proposed point has a higher likelihood than the
00285       // current point, so we should go there
00286       return kTRUE;
00287    }
00288    else {
00289       // generate numbers on a log distribution to decide
00290       // whether to go to xPrime or stay at x
00291       //Double_t rand = fGen.Uniform(1.0);
00292       Double_t rand = RooRandom::uniform();
00293       if (fType == kLog) {
00294          rand = TMath::Log(rand);
00295          // kbelasco: should this be changed to just (-rand > a) for logical
00296          // consistency with below test when fType == kRegular?
00297          if (-1.0 * rand >= a)
00298             // we chose to go to the new proposed point
00299             // even though it has a lower likelihood than the current one
00300             return kTRUE;
00301       } else {
00302          // fType must be kRegular
00303          // kbelasco: ensure that we never visit a point where PDF == 0
00304          //if (rand <= a)
00305          if (rand < a)
00306             // we chose to go to the new proposed point
00307             // even though it has a lower likelihood than the current one
00308             return kTRUE;
00309       }
00310       return kFALSE;
00311    }
00312 }
00313 
00314 Double_t MetropolisHastings::CalcNLL(Double_t xL)
00315 {
00316    if (fType == kLog) {
00317       if (fSign == kNegative)
00318          return xL;
00319       else
00320          return -xL;
00321    } else {
00322       if (fSign == kPositive)
00323          return -1.0 * TMath::Log(xL);
00324       else
00325          return -1.0 * TMath::Log(-xL);
00326    }
00327 }

Generated on Tue Jul 5 15:14:35 2011 for ROOT_528-00b_version by  doxygen 1.5.1