SRN – “Scalable Metropolis-Hastings for Exact Bayesian Inference with Large Datasets” by Cornish et al.

When faced with large datasets, standard MCMC methods have a cost of O(n) per step due to evaluating the log-likelihood for n data points. Can we make it faster?

My COVID-19 self-isolation plan includes finishing as many to-read articles as possible and writing up my progress here as often as I can. This weekend, I read the paper “Scalable Metropolis-Hastings for Exact Bayesian Inference with Large Datasets” by Robert Cornish, Paul Vanetti, Alexandre Bouchard-Côté, George Deligiannidis, and Arnaud Doucet. This is an ICML 2019 paper and it is a dense one.

It attempts to solve one problem. When faced with large datasets, standard MCMC methods have a cost of O(n) per step due to evaluating the log-likelihood for n data points. Can we make it faster? In a typical MH step (with symmetric proposals), we accept the proposal with probability \alpha(\theta,\theta') = \min(1, \frac{\pi(\theta')}{\pi(\theta)}). This is typically implemented on the log scale with

# Symmetric-proposal MH step, computed on the log scale
log_alpha <- log_post(theta_proposal) - log_post(theta_current)
if (log(runif(1)) < log_alpha) {
    theta_current <- theta_proposal  # accept; otherwise keep theta_current
}

The authors realize that it is wasteful to evaluate the log-likelihood at every data point before rejecting a proposal. Their solution, Scalable Metropolis-Hastings (SMH), is based on a combination of three ideas: (1) factorization of the posterior density (the trivial factorization has 1 + n factors: the prior and one likelihood term per data point), (2) fast simulation of rejection by thinning a Poisson point process, and (3) concentration of the posterior around its mode when n is large.
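To make idea (1) concrete, here is a minimal sketch in R of a factorized acceptance step. It is my own illustration, not the paper's code: `factorized_accept`, `log_ratio_fns`, and the toy Gaussian target are all hypothetical names and assumptions. The proposal is accepted only if an independent coin flip succeeds for every factor, so the loop can stop at the first failure.

```r
# Sketch of a factorized MH acceptance step (names are illustrative).
# log_ratio_fns: list of functions; each returns
#   log pi_i(theta_proposal) - log pi_i(theta_current) for one factor.
factorized_accept <- function(theta_current, theta_proposal, log_ratio_fns) {
  n_evaluated <- 0
  for (f in log_ratio_fns) {
    n_evaluated <- n_evaluated + 1
    # Per-factor acceptance probability is min(1, ratio_i), kept on the log scale
    log_alpha_i <- min(0, f(theta_current, theta_proposal))
    if (log(runif(1)) >= log_alpha_i) {
      # First failed coin flip: reject without touching the remaining factors
      return(list(accept = FALSE, n_evaluated = n_evaluated))
    }
  }
  list(accept = TRUE, n_evaluated = n_evaluated)
}

# Toy target (an assumption): N(0, 10^2) prior and y_i ~ N(theta, 1),
# which gives the trivial factorization with 1 + n factors.
set.seed(1)
y <- rnorm(50, mean = 1)
prior_ratio <- function(cur, prop) {
  dnorm(prop, 0, 10, log = TRUE) - dnorm(cur, 0, 10, log = TRUE)
}
lik_ratios <- lapply(y, function(y_i) {
  force(y_i)  # make sure each closure captures its own data point
  function(cur, prop) {
    dnorm(y_i, prop, 1, log = TRUE) - dnorm(y_i, cur, 1, log = TRUE)
  }
})
log_ratio_fns <- c(list(prior_ratio), lik_ratios)

res <- factorized_accept(1.0, 1.2, log_ratio_fns)
res$accept       # TRUE or FALSE
res$n_evaluated  # how many factors were evaluated before the decision
```

Two caveats on this sketch. The product of the per-factor probabilities is never larger than the usual MH acceptance probability, so factorizing alone makes the chain accept less often. And an accepted proposal still touches all 1 + n factors, which is why ideas (2) and (3) are needed on top.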

I first got excited about this paper when I saw the fast simulation of Bernoulli random variables (in Algorithm 1), because there is a very similar problem on simulating a Poisson process with inhomogeneous rates in Stat 210, for which I was a Teaching Fellow (TF). I am very happy to see this result being useful in cutting-edge research. The fast simulation relies on access to a lower bound for each factor of the acceptance probability, and the tighter these lower bounds, the more efficient the implementation. Suppose the target distribution has m factors. Once we pay a one-off O(m) setup cost, each acceptance step can be implemented with an average number of evaluations of the order of the upper bound on the negative log acceptance probability.
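The thinning trick can be sketched in a few lines of R. This is my own simplified illustration in the spirit of Algorithm 1, not a transcription of it. Write each factor's acceptance probability as exp(-lambda_i) and assume we know upper bounds lambda_bar_i >= lambda_i (equivalently, lower bounds on the probabilities). The names `thinned_accept`, `lambda_fn`, and `lambda_bar` are hypothetical. We simulate a Poisson number of candidate points from the bounds, then keep each candidate with probability lambda_i / lambda_bar_i; the proposal is accepted exactly when no candidate survives, which happens with probability exp(-sum(lambda_i)).

```r
# Sketch: simulate "all m Bernoulli factors succeed" without visiting every factor.
# lambda_fn(i): true rate lambda_i = -log(alpha_i) >= 0 for factor i (the costly part)
# lambda_bar:   vector of upper bounds with lambda_fn(i) <= lambda_bar[i]
thinned_accept <- function(lambda_fn, lambda_bar) {
  total_bar <- sum(lambda_bar)       # O(m) here; precomputed once in a real implementation
  n_points <- rpois(1, total_bar)    # number of candidate rejection events
  if (n_points == 0) {
    return(list(accept = TRUE, n_evaluated = 0))
  }
  # Assign each candidate to a factor with probability proportional to its bound.
  # A real implementation would use a precomputed alias table for O(1) draws.
  idx <- sample.int(length(lambda_bar), n_points, replace = TRUE, prob = lambda_bar)
  for (k in seq_along(idx)) {
    i <- idx[k]
    # Thinning: keep the candidate with probability lambda_i / lambda_bar_i
    if (runif(1) < lambda_fn(i) / lambda_bar[i]) {
      return(list(accept = FALSE, n_evaluated = k))
    }
  }
  list(accept = TRUE, n_evaluated = n_points)
}

# Toy check with made-up rates: the acceptance frequency should be
# close to exp(-sum(lambda)).
set.seed(1)
lambda <- runif(1000, 0, 0.002)
lambda_bar <- rep(0.002, 1000)
res <- replicate(20000, thinned_accept(function(i) lambda[i], lambda_bar)$accept)
c(empirical = mean(res), exact = exp(-sum(lambda)))
```

The expected number of calls to `lambda_fn` is at most `sum(lambda_bar)`, which is 2 in this toy example even though there are 1000 factors. This is where tight bounds pay off.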

Another highlight of the paper is the smart application of the Bernstein-von Mises (BvM) phenomenon, which states that the posterior distribution concentrates around the mode at rate 1/\sqrt{n} as n goes to infinity. BvM was an important topic in Stat 213, which is another class that I TFed. I have seen BvM used (1) to select the step size in random-walk MH, (2) to initialize an MH chain, and (3) to motivate the Laplace approximation of the posterior distribution. In this paper, it motivates a Taylor expansion of the potential function near the posterior mode, which subsequently leads to the factorization behind the SMH kernel. The SMH kernel leaves the target distribution invariant and, if I am not mistaken, this property explains the use of “exact” in the title. More importantly, since the acceptance probability in SMH is factorizable, we can implement the acceptance step with the aforementioned fast simulation, provided that the log-density can be differentiated enough times. Intuitively, the higher the differentiability, the closer the Taylor expansion, the tighter the bound on the acceptance rate, and the fewer evaluations we need. In fact, as later justified with Theorem 3.1, the average computational cost decays geometrically with respect to the order of the Taylor expansion. With a second-order Taylor expansion, which requires third-order differentiability, the expected computational cost behaves like O(1/\sqrt{n}). Hooray! I am always a big fan of “the bigger the problem, the easier the solution”.
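A quick numerical sanity check of the O(1/\sqrt{n}) intuition, as a toy sketch under my own assumptions rather than the paper's argument: take a single softplus term (as in logistic regression), expand it to second order around a fixed point, and evaluate the remainder at a distance 1/\sqrt{n} from that point, which is the scale BvM suggests. Each remainder is of order n^{-3/2}, so n identical terms add up to something of order n^{-1/2}. The functions `softplus`, `taylor2`, and `total_remainder` are names I made up for this illustration.

```r
# One logistic-regression-style term: softplus(t) = log(1 + exp(t))
softplus <- function(t) log1p(exp(t))

# Second-order Taylor expansion of softplus around t0
taylor2 <- function(t, t0) {
  p <- plogis(t0)  # first derivative of softplus at t0; p * (1 - p) is the second
  softplus(t0) + p * (t - t0) + 0.5 * p * (1 - p) * (t - t0)^2
}

# Sum of n identical remainders, evaluated at BvM-scale distance from t0
total_remainder <- function(n, t0 = 0.3) {
  h <- 1 / sqrt(n)
  n * abs(softplus(t0 + h) - taylor2(t0 + h, t0))
}

ns <- c(1e2, 1e4, 1e6)
totals <- sapply(ns, total_remainder)
totals             # shrinks as n grows
totals * sqrt(ns)  # roughly constant if the total is O(1/sqrt(n))
```

Treating every data term as identical is of course a simplification, but it shows why the bounds fed to the fast simulation get tighter, not looser, as n grows.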

I learned a lot from this paper, especially from the opening paragraph of Section 3 and from Section 3.2. But there is a question that I keep asking myself: if n is very large and we have an estimate of the posterior mode that is within O(1/\sqrt{n}) of the true mode, how long does it take for SMH to beat simply declaring the posterior Gaussian by invoking BvM and sleeping on the result?

Author: PhyllisWithData

Statistics PhD student at Harvard University.
