<p><strong>Natural-Gradient Variational Inference 2: ImageNet-scale</strong> (MLG Blog, the blog of the Machine Learning Group at the University of Cambridge; published 2021-11-24; <a href="https://mlg.eng.cam.ac.uk/blog/2021/11/24/ngvi-bnns-part-2">https://mlg.eng.cam.ac.uk/blog/2021/11/24/ngvi-bnns-part-2</a>)</p>
<p>In our <a href="https://mlg-blog.com/2021/04/13/ngvi-bnns-part-1.html">previous post</a>, we derived a natural-gradient variational inference (NGVI) algorithm for neural networks, detailing all our approximations and providing intuition. We saw it converge faster than more naive variational inference algorithms on relatively small-scale data. But a couple of key questions remain:</p>
<ol class="custom parentheses_roman">
<li>Can we scale such algorithms to very large datasets and architectures?</li>
<li>Did we gain anything from having additional Bayesian principles, or put differently, do we have better performance than SGD or Adam?</li>
</ol>
<p>We tackle these questions in this second part of our natural-gradient variational inference series.
We show that we can get good performance at large scales with Bayesian principles, while maintaining reasonable uncertainties.
We start by focussing on question (i): the issue of scalability.
We notice similarities between our NGVI algorithm and Adam, and exploit this to borrow tricks that the community has developed for Adam over many years. This allows us to scale up to very large datasets/architectures.
We then turn our focus to question (ii): have we improved on neural networks’ poorly-calibrated uncertainties thanks to our Bayesian thinking? We will see some benefits. Along the way, we will discuss the price we pay for them.</p>
<p>This second part of the blog closely follows a paper I was involved in, <a href="https://arxiv.org/pdf/1906.02506.pdf">Practical Deep Learning with Bayesian Principles (Osawa et al., 2019)</a>. There is also a <a href="https://github.com/team-approx-bayes/dl-with-bayes">codebase</a> if you are interested in experimenting with our algorithm, VOGN (Variational Online Gauss-Newton). As a postscript to this blog post, we summarise some good practices for training your own neural network with VOGN.
$\newcommand{\vparam}{\boldsymbol{\theta}}$
$\newcommand{\veta}{\boldsymbol{\eta}}$
$\newcommand{\vphi}{\boldsymbol{\phi}}$
$\newcommand{\vmu}{\boldsymbol{\mu}}$
$\newcommand{\vSigma}{\boldsymbol{\Sigma}}$
$\newcommand{\vm}{\mathbf{m}}$
$\newcommand{\vF}{\mathbf{F}}$
$\newcommand{\vI}{\mathbf{I}}$
$\newcommand{\vg}{\mathbf{g}}$
$\newcommand{\vH}{\mathbf{H}}$
$\newcommand{\vs}{\mathbf{s}}$
$\newcommand{\myexpect}{\mathbb{E}}$
$\newcommand{\pipe}{\,|\,}$
$\newcommand{\data}{\mathcal{D}}$
$\newcommand{\loss}{\mathcal{L}}$
$\newcommand{\gauss}{\mathcal{N}}$</p>
<h2 id="vogn-vs-adam">VOGN vs Adam</h2>
<p>We start with the equations for the VOGN algorithm, derived in our <a href="https://mlg-blog.com/2021/04/13/ngvi-bnns-part-1.html">previous blog post</a>. This also serves as a quick summary of notation: please look at the previous blog post if anything is unclear! (Colours are purely for illustrative purposes.)</p>
<p>\begin{align}
\label{eq:VOGN_mu}
\vmu_{t+1} &= \vmu_t - \alpha_t \frac{ {\color{purple}\hat{\vg}(\vparam_t)} + {\color{blue}\tilde{\delta}}\vmu_t}{\vs_{t+1} + {\color{blue}\tilde{\delta}}}, \newline
\label{eq:VOGN_Sigma}
\vs_{t+1} &= (1-\beta_t)\vs_t + \beta_t \frac{1}{M} \sum_{i\in\mathcal{M}_t}\left( {\color{purple}\vg_i(\vparam_t)} \right)^2.
\end{align}</p>
<p>Remember, we are updating the parameters of our (mean-field Gaussian) approximate posterior $q_t(\vparam)=\gauss(\vparam; \vmu_t, \vSigma_t)$, where $\vparam$ are the parameters of a neural network. We do this by iteratively updating two vectors, $\vmu_t$ and $\vs_t$, where $t$ indexes the iteration.
We have a zero-mean prior $p(\vparam) = \gauss(\vparam; {\color{blue}\boldsymbol{0}}, {\color{blue}\delta^{-1}} \vI)$, and ${\color{blue}\tilde{\delta}} = {\color{blue}{\delta}} / N$.
Our dataset consists of $N$ data examples, and we are taking per-example gradients of the negative log-likelihood $\color{purple}\vg_i(\vparam_t)$ at a sample from our current approximate posterior, $\vparam_t \sim q_t(\vparam)$.
For a randomly-sampled minibatch $\mathcal{M}_t$ of size $M$, we have defined the average gradient ${\color{purple}\hat{\vg}(\vparam_t)} = \frac{1}{M}\sum_{i\in\mathcal{M}_t} {\color{purple}\vg_i(\vparam_t)}$.
There is a simple relation between $\vSigma_t$ and $\vs_t$, $\vSigma_t^{-1} = N\vs_t + {\color{blue}\delta \vI}$.
Finally, $\alpha_t>0$ and $0<\beta_t<1$ are learning rates, and all operations are element-wise.</p>
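<p>To make the update concrete, here is a minimal pure-Python sketch of one VOGN step, written directly from Equations \eqref{eq:VOGN_mu} and \eqref{eq:VOGN_Sigma}, together with the reparameterised sample $\vparam_t \sim q_t(\vparam)$ that uses $\vSigma_t^{-1} = N\vs_t + \delta\vI$. It assumes parameters stored as lists of floats and a hypothetical toy problem whose per-example negative log-likelihood is $\frac{1}{2}(\theta - x_i)^2$. The function names are illustrative and are not taken from the authors’ codebase.</p>

```python
import math
import random


def vogn_step(mu, s, per_example_grads, alpha, beta, delta, N):
    """One VOGN update, element-wise, written from Equations VOGN_mu and VOGN_Sigma.

    mu, s: lists of floats (posterior mean and Gauss-Newton moving average).
    per_example_grads: M gradient lists g_i, all evaluated at one sample theta_t.
    delta: prior precision; N: dataset size.
    """
    M = len(per_example_grads)
    D = len(mu)
    delta_tilde = delta / N
    # Average gradient g_hat, and the Gauss-Newton term (squares first, then averaged).
    g_hat = [sum(g[d] for g in per_example_grads) / M for d in range(D)]
    gn = [sum(g[d] ** 2 for g in per_example_grads) / M for d in range(D)]
    # Moving-average update for s.
    s_new = [(1 - beta) * s[d] + beta * gn[d] for d in range(D)]
    # Mean update: divide by s_{t+1} + delta_tilde, with no square root.
    mu_new = [mu[d] - alpha * (g_hat[d] + delta_tilde * mu[d]) / (s_new[d] + delta_tilde)
              for d in range(D)]
    return mu_new, s_new


def posterior_std(s, delta, N):
    """Standard deviations of q_t from the diagonal relation Sigma^{-1} = N s + delta."""
    return [1.0 / math.sqrt(N * s_d + delta) for s_d in s]


def sample_theta(mu, s, delta, N, rng):
    """Reparameterised sample theta_t = mu + sigma * eps, with eps ~ N(0, I)."""
    return [m + sd * rng.gauss(0.0, 1.0) for m, sd in zip(mu, posterior_std(s, delta, N))]


# Hypothetical toy problem: per-example NLL 0.5 * (theta - x_i)^2, so g_i = theta - x_i.
rng = random.Random(0)
data = [1.0, 2.0, 3.0, 4.0]
N, delta = len(data), 1.0
mu, s = [0.0], [1.0]
for _ in range(200):
    theta = sample_theta(mu, s, delta, N, rng)
    grads = [[theta[0] - x] for x in data]  # one full-batch minibatch, M = N
    mu, s = vogn_step(mu, s, grads, alpha=0.1, beta=0.1, delta=delta, N=N)
print(mu, s)
```

<p>For this toy problem the exact posterior mean is $\sum_i x_i / (N + \delta) = 2$; the mean iterate settles near that value, fluctuating slightly because gradients are evaluated at samples $\vparam_t$ rather than at $\vmu_t$.</p>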
<p>It turns out that this update equation is very similar to Adam (<a href="https://arxiv.org/pdf/1412.6980.pdf">Kingma & Ba, 2015</a>). To see this, let’s write down the form that commonly-used optimisers such as SGD, RMSProp (<a href="https://www.cs.toronto.edu/~tijmen/csc321/slides/lecture_slides_lec6.pdf">Hinton, 2012</a>), and Adam take:</p>
<p>\begin{align}
\label{eq:Adam_mu}
\vmu_{t+1} &= \vmu_t - \alpha_t \frac{ {\color{purple}\hat{\vg}(\vmu_t)} + {\color{blue}\delta}\vmu_t} {\sqrt{\vs_{t+1}} + \epsilon}, \newline
\label{eq:Adam_Sigma}
\vs_{t+1} &= (1-\beta_t)\vs_t + \beta_t \left( \frac{1}{M} \sum_{i\in\mathcal{M}_t} {\color{purple}\vg_i(\vmu_t)} + {\color{blue}\delta} \vmu_t \right) ^2,
\end{align}</p>
<p>where $\delta>0$ is our weight-decay regulariser, and $\epsilon>0$ is a small scalar constant.
Immediately we can see striking similarities in the overall form of the equations! Let’s take a closer look at the similarities and differences:</p>
<ol>
<li><em>Similarity</em>: Both updates for $\vmu_t$ are similar, of the form $\vmu_{t+1} = \vmu_t - \alpha_t (\hat{\vg} + \delta \vmu_t) / \mathrm{function}{(\vs_{t+1})}$.</li>
<li><em>Difference</em>: The denominator in the update for the means is slightly different. VOGN uses $(\vs_{t+1} + \tilde{\delta})$, while Adam has a square root, $\sqrt{\vs_{t+1}}$.</li>
<li><em>Difference</em>: VOGN calculates gradients at a sample $\vparam_t \sim q_t(\vparam)$, while Adam calculates gradients just at the mean $\vmu_t$. In fact, when we remove this difference, we get a deterministic version of VOGN, which we call OGN.</li>
<li><em>Similarity</em>: Both updates for $\vs_t$ take the form of a moving average update.</li>
<li><em>Difference</em>: VOGN uses a Gauss-Newton approximation, requiring $\sum_i (\vg_i)^2$, while Adam (and other SGD-based algorithms) use a gradient-magnitude, $\left( \sum_i \vg_i \right) ^2$. Note that in VOGN, the sum is <em>outside</em> the square, while in SGD-based algorithms, the sum is <em>inside</em> the square.</li>
</ol>
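<p>Differences 2 and 5 are easiest to see side by side. The minimal pure-Python sketch below (illustrative helper names; no momentum or bias correction on the Adam side) computes the VOGN statistic $\frac{1}{M}\sum_i \vg_i^2$ and the Adam-style statistic $\left(\frac{1}{M}\sum_i \vg_i\right)^2$, and implements the Adam-form update of Equations \eqref{eq:Adam_mu} and \eqref{eq:Adam_Sigma}. By Jensen’s inequality the first statistic is never smaller than the second; when per-example gradients cancel, only the first retains any curvature information.</p>

```python
import math


def gauss_newton_diag(per_example_grads):
    """VOGN-style statistic: average of squared per-example gradients (sum outside the square)."""
    M = len(per_example_grads)
    return [sum(g[d] ** 2 for g in per_example_grads) / M
            for d in range(len(per_example_grads[0]))]


def grad_magnitude_diag(per_example_grads):
    """Adam/RMSProp-style statistic: square of the averaged gradient (sum inside the square)."""
    M = len(per_example_grads)
    return [(sum(g[d] for g in per_example_grads) / M) ** 2
            for d in range(len(per_example_grads[0]))]


def adam_like_step(mu, s, g_hat, alpha, beta, delta, eps=1e-8):
    """Update of Equations Adam_mu / Adam_Sigma (no momentum, no bias correction)."""
    r = [g + delta * m for g, m in zip(g_hat, mu)]  # gradient plus weight decay
    s_new = [(1 - beta) * sd + beta * rd ** 2 for sd, rd in zip(s, r)]
    mu_new = [m - alpha * rd / (math.sqrt(sd) + eps) for m, rd, sd in zip(mu, r, s_new)]
    return mu_new, s_new


# Two examples whose gradients point in opposite directions:
grads = [[1.0], [-1.0]]
print(gauss_newton_diag(grads))    # [1.0]: curvature information survives
print(grad_magnitude_diag(grads))  # [0.0]: the gradients cancel before squaring
```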
<p>The Gauss-Newton approximation (Difference 5) is a better approximation to second-order information (the Hessian) than the gradient-magnitude approach. This better approximation is likely why VOGN does not require a square root over $\vs_{t+1}$ in the update for the means (Difference 2). However, computing the Gauss-Newton approximation requires the individual per-example gradients $\vg_i$, whereas a standard backward pass in frameworks such as PyTorch only returns their minibatch sum. This extra computation makes VOGN slower (per epoch) than Adam,
even when using speed-up tricks (<a href="https://arxiv.org/pdf/1510.01799.pdf">Goodfellow, 2015</a>).</p>
<p>The similarities in the equations indicate that we might be able to take techniques people use to scale Adam up to large datasets and architectures, and apply similar ideas to scale VOGN up.
We can use batch normalisation, momentum, clever initialisations, data augmentation, learning rate scheduling, and so on.</p>
<h2 id="borrowing-existing-deep-learning-techniques">Borrowing existing deep-learning techniques</h2>
<p>Let’s go over a list of each of the changes we make, providing some intuition for them. Please see <a href="https://arxiv.org/pdf/1906.02506.pdf">Osawa et al. (2019)</a> for further details. Using all these techniques, we are able to scale VOGN to datasets like CIFAR-10 and ImageNet, and architectures such as ResNets.</p>
<h3 id="1-batch-normalisation-and-momentum">1. Batch normalisation and momentum</h3>
<p>Batch normalisation (<a href="http://arxiv.org/pdf/1502.03167.pdf">Ioffe & Szegedy, 2015</a>) empirically speeds up and stabilises training of neural networks.
We can use BatchNorm layers as is normal in deep learning. In fact, in our VOGN implementation, we found that we do not have to maintain uncertainty over BatchNorm parameters.</p>
<p>We can also use momentum for VOGN in a similar way to Adam: we introduce momentum on $\vmu_t$, along with a momentum rate.</p>
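<p>A minimal sketch, assuming a simple heavy-ball form of momentum (the exact form used in Algorithm 1 may differ): a buffer $\vm$ accumulates the regularised gradient $\hat{\vg} + \tilde{\delta}\vmu_t$ with momentum rate $\beta_1$, and the mean step divides it by $\vs_{t+1} + \tilde{\delta}$ as in Equation \eqref{eq:VOGN_mu}. The names <code class="language-plaintext highlighter-rouge">vogn_momentum_step</code>, <code class="language-plaintext highlighter-rouge">m</code> and <code class="language-plaintext highlighter-rouge">beta1</code> are hypothetical; with <code class="language-plaintext highlighter-rouge">beta1 = 0</code> the function reduces to the plain VOGN mean update.</p>

```python
def vogn_momentum_step(mu, m, s_new, g_hat, alpha, beta1, delta_tilde):
    """Mean update with an assumed heavy-ball momentum buffer m.

    s_new is the already-updated Gauss-Newton moving average s_{t+1}.
    """
    # Accumulate the regularised gradient into the momentum buffer.
    m_new = [beta1 * mk + (g + delta_tilde * mu_k)
             for mk, g, mu_k in zip(m, g_hat, mu)]
    # Precondition the momentum by s_{t+1} + delta_tilde (no square root).
    mu_new = [mu_k - alpha * mk / (sk + delta_tilde)
              for mu_k, mk, sk in zip(mu, m_new, s_new)]
    return mu_new, m_new


mu, m = [1.0], [0.0]  # momentum buffer initialised to zero
for _ in range(3):
    mu, m = vogn_momentum_step(mu, m, s_new=[1.0], g_hat=[0.5],
                               alpha=0.1, beta1=0.9, delta_tilde=0.0)
print(mu, m)
```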
<h3 id="2-initialisation">2. Initialisation</h3>
<p>Over many years of training neural networks with SGD and Adam, the community has found tricks to speed up training using clever initialisation. We can get these same benefits by changing VOGN to look more like Adam at initialisation, before slowly relaxing our algorithm to become the full VOGN algorithm later in training.</p>
<p>This is achieved by introducing a tempering parameter $\tau$ in front of the KL term in the ELBO, which propagates its way through to the VOGN equations. To see exactly where $\tau$ crops up, please look at Equation 4 from <a href="https://arxiv.org/pdf/1906.02506.pdf">Osawa et al. (2019)</a>, or see <a href="#figure-VOGNalgorithm">Algorithm 1</a> below. As $\tau\rightarrow 0$, we (loosely speaking) get more similar to Adam. So, at the beginning of training, we initialise $\tau$ at something small (like $0.1$) and increase to $1$ during the first few optimisation steps.</p>
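<p>As a sketch of the warm-start, the hypothetical helper below ramps $\tau$ linearly from $0.1$ to $1$ over a chosen number of steps; the linear shape is an assumption. A second helper assumes that, because $\tau$ multiplies the KL term, it scales the prior contribution in the mean update to $\tilde{\delta} = \tau\delta/N$, so a small $\tau$ weakens the prior term early in training.</p>

```python
def tau_schedule(step, warmup_steps, tau_init=0.1, tau_final=1.0):
    """Linearly increase the tempering parameter tau over the first warmup_steps steps."""
    if warmup_steps <= 0 or step >= warmup_steps:
        return tau_final
    frac = step / warmup_steps
    return tau_init + frac * (tau_final - tau_init)


def tempered_delta_tilde(tau, delta, N):
    """Assumed tempered prior term tau * delta / N in the VOGN mean update."""
    return tau * delta / N


print([round(tau_schedule(t, warmup_steps=4), 3) for t in range(6)])
```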
<p>Other initialisations are the same as Adam: $\vmu_t$ is initialised using <code class="language-plaintext highlighter-rouge">init.xavier_normal</code> from PyTorch (<a href="https://proceedings.mlr.press/v9/glorot10a/glorot10a.pdf">Glorot & Bengio, 2010</a>) and the momentum term is initialised to zero, like in Adam. VOGN’s $\vs_t$ is initialised using an additional forward pass through the first minibatch of data.</p>
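<p>A minimal pure-Python sketch of these initialisations. The Xavier (Glorot) normal draw uses standard deviation $\sqrt{2/(n_\text{in} + n_\text{out})}$ (times an optional gain), the same formula as PyTorch’s Xavier-normal initialiser, and the momentum buffer starts at zero. For $\vs_0$, the helper assumes that the extra pass over the first minibatch computes the same Gauss-Newton statistic $\frac{1}{M}\sum_i \vg_i^2$ as the VOGN update, evaluated at $\vmu_0$; that is an assumption, and the helper names are hypothetical.</p>

```python
import math
import random


def xavier_normal(fan_out, fan_in, gain=1.0, rng=None):
    """fan_out x fan_in matrix with entries ~ N(0, std^2), std = gain * sqrt(2 / (fan_in + fan_out)).

    This is the Glorot & Bengio (2010) normal initialisation; PyTorch's
    Xavier-normal initialiser uses the same standard deviation.
    """
    rng = rng or random.Random(0)
    std = gain * math.sqrt(2.0 / (fan_in + fan_out))
    return [[rng.gauss(0.0, std) for _ in range(fan_in)] for _ in range(fan_out)]


def init_s_from_first_batch(per_example_grads):
    """Hypothetical s_0: average of squared per-example gradients at mu_0 on the first minibatch."""
    M = len(per_example_grads)
    return [sum(g[d] ** 2 for g in per_example_grads) / M
            for d in range(len(per_example_grads[0]))]


mu0 = [w for row in xavier_normal(3, 5) for w in row]  # flattened 3x5 weight matrix
m0 = [0.0] * len(mu0)                                   # momentum starts at zero, as in Adam
```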
<h3 id="3-learning-rate-scheduling">3. Learning rate scheduling</h3>
<p>We can use learning rate scheduling for $\alpha_t$ exactly as is done for Adam and SGD in large-scale training: we regularly decay the learning rate by a constant factor (typically 10).</p>
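<p>A step schedule of this kind takes only a few lines. The sketch below divides $\alpha_t$ by 10 at each milestone epoch; the milestones at epochs 30 and 60 mirror the ImageNet runs shown later in this post, and the helper name is hypothetical.</p>

```python
def step_decay_lr(epoch, base_lr, milestones=(30, 60), factor=10.0):
    """Divide base_lr by `factor` once for every milestone epoch already reached."""
    n_decays = sum(1 for m in milestones if epoch >= m)
    return base_lr / (factor ** n_decays)


for epoch in (0, 29, 30, 59, 60, 89):
    print(epoch, step_decay_lr(epoch, base_lr=1e-3))
```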
<h3 id="4-data-augmentation">4. Data augmentation</h3>
<p>When training on image datasets, data augmentation (DA) can improve performance drastically. For example, we can use random cropping and/or random horizontal flipping of images.</p>
<p>Unfortunately, directly applying DA to VOGN does not lead to improvements, and also negatively affects uncertainty calibration.
But we note that DA can be viewed as increasing the effective size of our dataset. Remember that the dataset size $N$ affects VOGN (unlike Adam and SGD, where $N$ does not appear because it cannot be identified separately from the weight-decay factor). So, we view DA as increasing $N$ by some factor depending on the exact DA technique: for example, if we horizontally flip each image with a probability of 50%, we increase $N$ by a factor of 2.</p>
<p>This is still a heuristic, and not mathematically rigorous. It seems to work quite well in our experiments, but requires further theoretical investigation. It is also closely related to KL-annealing in variational inference, as well as the recently-termed ‘cold posterior effect’ (<a href="https://arxiv.org/pdf/2002.02405.pdf">Wenzel et al., 2020</a>; <a href="https://arxiv.org/pdf/2011.12328.pdf">Loo et al., 2021</a>; <a href="https://arxiv.org/pdf/2008.05912.pdf">Aitchison, 2021</a>).</p>
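<p>The heuristic only changes where $N$ enters VOGN. The sketch below (hypothetical helper names and made-up values) multiplies $N$ by a factor $\rho$ and shows the two consequences visible in the equations above: the prior term $\tilde{\delta} = \delta/N$ in Equation \eqref{eq:VOGN_mu} shrinks, and the posterior precision $\vSigma_t^{-1} = N\vs_t + \delta\vI$ grows, so the approximate posterior becomes narrower.</p>

```python
def effective_dataset_size(N, rho):
    """Heuristic: data augmentation multiplies the dataset size by rho."""
    return rho * N


def vogn_prior_term(delta, N_eff):
    """delta_tilde = delta / N, the prior term in the VOGN mean update."""
    return delta / N_eff


def posterior_precision(s, delta, N_eff):
    """Diagonal of Sigma^{-1} = N s + delta."""
    return [N_eff * s_d + delta for s_d in s]


N, delta = 50_000, 1e-2  # CIFAR-10 has 50,000 training images; delta is a made-up value
s = [1e-3]               # made-up Gauss-Newton statistic for a single weight
for rho in (1, 2):       # rho = 2: random horizontal flips with probability 0.5
    N_eff = effective_dataset_size(N, rho)
    print(rho, vogn_prior_term(delta, N_eff), posterior_precision(s, delta, N_eff))
```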
<h3 id="5-distributed-training">5. Distributed training</h3>
<p>We would like to use multiple GPUs in parallel to perform large experiments quickly. Typically, we would just parallelise over data, splitting up large minibatches by sending different data to different GPUs. With VOGN, we can also parallelise computation over Monte-Carlo samples $\vparam_t \sim q_t(\vparam)$: every GPU can use a different sample $\vparam_t$. This reduces gradient variance during training, and we empirically find it leads to quicker convergence.</p>
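<p>To illustrate the idea without any GPUs, the sketch below simulates workers in a single process: each worker takes its own slice of the minibatch and its own Monte-Carlo sample $\vparam_t$, computes $\hat{\vg}$ and the Gauss-Newton average locally, and a plain average stands in for an AllReduce. It uses a one-dimensional parameter for brevity, and all names are hypothetical. Averaging per-worker means matches the full-minibatch mean only when the shards have equal size.</p>

```python
import random


def per_worker_estimates(mu, sigma, data, grad_fn, rng):
    """One simulated worker: its own sample theta ~ N(mu, sigma^2) and its own data shard."""
    theta = mu + sigma * rng.gauss(0.0, 1.0)
    grads = [grad_fn(theta, x) for x in data]
    g_hat = sum(grads) / len(grads)               # average gradient
    h_hat = sum(g * g for g in grads) / len(grads)  # Gauss-Newton average
    return g_hat, h_hat


def all_reduce_mean(values):
    """Stand-in for an AllReduce that averages a scalar across workers."""
    return sum(values) / len(values)


def distributed_estimates(mu, sigma, minibatch, grad_fn, n_workers, seed=0):
    # Each worker gets a different slice of the data AND a different MC sample.
    shards = [minibatch[k::n_workers] for k in range(n_workers)]
    rngs = [random.Random(seed + k) for k in range(n_workers)]
    results = [per_worker_estimates(mu, sigma, shard, grad_fn, r)
               for shard, r in zip(shards, rngs)]
    g = all_reduce_mean([g for g, _ in results])
    h = all_reduce_mean([h for _, h in results])
    return g, h


# Toy per-example NLL 0.5 * (theta - x)^2, so the gradient is theta - x.
print(distributed_estimates(0.0, 0.3, [1.0, 2.0, 3.0, 4.0], lambda th, x: th - x, n_workers=2))
```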
<h3 id="6-external-damping-factor">6. External damping factor</h3>
<p>We introduce an external damping factor $\gamma$, added to $\vs_{t+1}$ in the denominator of Equation \eqref{eq:VOGN_mu} (<a href="https://arxiv.org/pdf/1712.02390.pdf">Zhang et al., 2018</a>).
This raises the lower bound on the diagonal of the precision $\vSigma_t^{-1}$ (equivalently, it caps the variances in the diagonal covariance $\vSigma_t$), preventing the step size and the sampling noise from becoming too large. However, this also detracts from the principled variational inference equations, and there is currently no theoretical justification for this (beyond the intuition we just provided).</p>
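<p>The effect of $\gamma$ is easiest to see in the worst case $\vs_{t+1} \to 0$. The sketch below assumes $\gamma$ is added both in the mean-update denominator and in the precision used for sampling, $\vSigma_t^{-1} = N(\vs_t + \tilde{\delta} + \gamma)$ (which reduces to $N\vs_t + \delta$ when $\gamma = 0$). The step is then at most $\alpha_t |\hat{\vg} + \tilde{\delta}\vmu_t| / (\tilde{\delta} + \gamma)$, and the sampling standard deviation at most $1/\sqrt{N(\tilde{\delta} + \gamma)}$. Helper names are hypothetical.</p>

```python
def damped_mean_update(mu, s_new, g_hat, alpha, delta_tilde, gamma):
    """VOGN mean update with an external damping factor gamma in the denominator."""
    return [m - alpha * (g + delta_tilde * m) / (sd + delta_tilde + gamma)
            for m, sd, g in zip(mu, s_new, g_hat)]


def damped_posterior_std(s, delta_tilde, gamma, N):
    """Standard deviation from the assumed damped precision N * (s + delta_tilde + gamma)."""
    return [(N * (sd + delta_tilde + gamma)) ** -0.5 for sd in s]


# Worst case: no curvature information yet (s = 0) and a tiny prior term.
for gamma in (0.0, 1e-2):
    print(gamma, damped_mean_update([0.0], [0.0], [1.0], 1.0, 1e-6, gamma))
```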
<h3 id="final-algorithm">Final algorithm</h3>
<p>Let’s recap. We derived the VOGN equations (Equations \eqref{eq:VOGN_mu} and \eqref{eq:VOGN_Sigma}) in the <a href="https://mlg-blog.com/2021/04/13/ngvi-bnns-part-1.html">previous blog post</a>. We started this post by comparing the equations to Adam, noting key similarities and differences. One key difference stemmed from the Gauss-Newton approximation, which slows VOGN down (per epoch) compared to SGD-based algorithms like Adam. Based on the similarities, we borrowed tricks used to scale Adam to large datasets, and applied them to VOGN.</p>
<p>All of these tricks are important to get VOGN’s results on ImageNet. The final algorithm is summarised in <a href="#figure-VOGNalgorithm">Algorithm 1</a> below.
One downside of VOGN compared to Adam is that it has additional hyperparameters that require tuning. At the end of this blog post, we provide best practices for tuning these hyperparameters.</p>
<div class="image-container">
<img src="/blog/assets/images/ngvi-bnns/VOGN_algorithm_figure.png" alt="Final algorithm, ready for running on ImageNet. Additional notes explaining key steps are in red. The vanilla VOGN equations (Equations \eqref{eq:VOGN_mu} and \eqref{eq:VOGN_Sigma}) are in Steps 8–12 & 18–19. The final list of hyperparameters are summarised in the bottom right. The final four hyperparameters are specific to VOGN, and we provide best practices for tuning them at the end of the blog post." id="figure-VOGNalgorithm" style="width: 100%; max-width: 700px" />
<p class="caption">
Algorithm 1: Final algorithm, ready for running on ImageNet. Additional notes explaining key steps are in red. The vanilla VOGN equations (Equations \eqref{eq:VOGN_mu} and \eqref{eq:VOGN_Sigma}) are in Steps 8–12 & 18–19. The final list of hyperparameters is summarised in the bottom right. The final four hyperparameters are specific to VOGN, and we provide best practices for tuning them at the end of the blog post.
</p>
</div>
<h2 id="results-on-imagenet">Results on ImageNet</h2>
<p>We are finally in a position to run <a href="#figure-VOGNalgorithm">Algorithm 1</a> on ImageNet and analyse the results.</p>
<div class="image-container">
<img src="/blog/assets/images/ngvi-bnns/ImageNet_results.png" alt="Results on ImageNet. VOGN converges in about as many epochs as Adam and SGD (top left plot), but is almost twice as slow per epoch (top middle plot). VOGN's calibration is better (top right plot). Overall, VOGN gets good accuracy and uncertainty metrics. See the paper for standard deviations over many runs." id="figure-ImageNetResults" style="width: 100%; max-width: 700px" />
<p class="caption">
Figure 1: Results on ImageNet. VOGN converges in about as many epochs as Adam and SGD (top left plot), but is almost twice as slow per epoch (top middle plot). VOGN's calibration is better (top right plot). Overall, VOGN gets good accuracy and uncertainty metrics. See the paper for standard deviations over many runs.
</p>
</div>
<p>Let’s go through these results slowly.</p>
<ul>
<li>Top left plot: VOGN’s convergence rate (per epoch) is very similar to Adam’s. The step increases (at epochs 30 and 60) are due to learning rate scheduling, which is best practice for training algorithms on ImageNet.</li>
<li>Top middle plot: VOGN is about twice as slow (total time) compared to SGD and Adam. This is impressive compared to other approaches like Bayes-By-Backprop (<a href="https://arxiv.org/pdf/1505.05424.pdf">Blundell et al., 2015</a>), which currently can’t scale to ImageNet even if given an order of magnitude more time!</li>
<li>Top right plot: In this calibration curve, the red line is closer to the diagonal than the other lines, showing better calibration. This plot is summarised in the ECE (Expected Calibration Error) column in the Table, where VOGN is better than SGD and Adam.</li>
<li>Turning our attention to the Table, MC-dropout gets very good ECE, but this comes at the cost of validation accuracy, and is only achieved after a fine-grained sweep of hyperparameters (specifically the dropout rate; see Appendix G in the paper).</li>
<li>OGN is a deterministic version of VOGN, where we do not use the reparameterisation trick to sample $\vparam_t$ during training (Steps 8 & 9 in <a href="#figure-VOGNalgorithm">Algorithm 1</a>), and instead just use the mean $\vmu_t$.</li>
<li>K-FAC has a Kronecker-factored structure as in <a href="https://arxiv.org/pdf/1811.12019.pdf">Osawa et al. (2018)</a>, where they impressively trained on ImageNet in very few iterations. <a href="https://towardsdatascience.com/introducing-k-fac-and-its-application-for-large-scale-deep-learning-4e3f9b443414">This blog post</a> provides an introduction to K-FAC at a large scale.</li>
<li>OGN and K-FAC perform well, but their metrics (particularly validation accuracy, validation negative-log-likelihood and ECE) are worse than VOGN’s.</li>
<li>Noisy K-FAC (<a href="https://arxiv.org/pdf/1712.02390.pdf">Zhang et al., 2018</a>) is similar to VOGN but adds K-FAC structure to the covariance matrix. It is therefore more computationally expensive than VOGN (slower per epoch and in total training time), even though it requires fewer epochs. It performs decently, but not as well as VOGN in this experiment.</li>
</ul>
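<p>Since ECE carries much of the calibration story, here is a minimal sketch of the usual equal-width-bin estimator: predictions are grouped by confidence, and ECE is the bin-size-weighted average of $|\text{accuracy} - \text{confidence}|$ over bins. The bin count of 10 and the half-open binning convention are assumptions; the paper’s exact settings may differ.</p>

```python
def expected_calibration_error(confidences, predictions, labels, n_bins=10):
    """ECE = sum over bins of (|bin| / n) * |accuracy(bin) - mean confidence(bin)|."""
    n = len(confidences)
    bins = [[] for _ in range(n_bins)]
    for c, p, y in zip(confidences, predictions, labels):
        # Bin k covers [k / n_bins, (k + 1) / n_bins); confidence 1.0 goes into the last bin.
        idx = min(int(c * n_bins), n_bins - 1)
        bins[idx].append((c, p == y))
    ece = 0.0
    for b in bins:
        if not b:
            continue
        acc = sum(correct for _, correct in b) / len(b)
        conf = sum(c for c, _ in b) / len(b)
        ece += len(b) / n * abs(acc - conf)
    return ece


# Four predictions made with 90% confidence, of which only three are correct.
print(expected_calibration_error([0.9] * 4, [1, 1, 1, 1], [1, 1, 1, 0]))
```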
<p>There are many more results in <a href="https://arxiv.org/pdf/1906.02506.pdf">Osawa et al. (2019)</a>, such as CIFAR-10 with a variety of architectures. We tend to see a similar story, where VOGN performs comparably on validation accuracy, and well on uncertainty metrics.</p>
<h3 id="some-bayesian-trade-offs">Some Bayesian trade-offs</h3>
<p>Due to the Bayesian nature of VOGN, we see some interesting trade-offs (see the paper for figures).</p>
<ol>
<li>
<p>Reducing the prior precision $\delta$ results in higher validation accuracy, but also a larger train-test gap, corresponding to more overfitting. With very small prior precisions, performance is similar to non-Bayesian methods like Adam.</p>
</li>
<li>
<p>Increasing the number of training MC samples ($K$ in <a href="#figure-VOGNalgorithm">Algorithm 1</a>) improves VOGN’s convergence rate and stability, as it reduces gradient variance during training. But this is at the cost of increased computation. Increasing the number of MC samples during testing improves generalisation.</p>
</li>
</ol>
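<p>Test-time MC averaging (trade-off 2 above) amounts to averaging predicted class probabilities over posterior samples. A minimal sketch, assuming a diagonal Gaussian posterior with mean $\vmu$ and standard deviations $\boldsymbol{\sigma}$, and a hypothetical <code class="language-plaintext highlighter-rouge">forward(theta, x)</code> function that returns logits:</p>

```python
import math
import random


def softmax(logits):
    """Numerically stable softmax."""
    mx = max(logits)
    exps = [math.exp(z - mx) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


def mc_predictive(x, mu, sigma, forward, n_samples, seed=0):
    """Approximate p(y | x) by averaging softmax outputs over n_samples
    draws theta ~ N(mu, diag(sigma^2)) from the approximate posterior."""
    rng = random.Random(seed)
    avg = None
    for _ in range(n_samples):
        theta = [m + s * rng.gauss(0.0, 1.0) for m, s in zip(mu, sigma)]
        probs = softmax(forward(theta, x))
        avg = probs if avg is None else [a + p for a, p in zip(avg, probs)]
    return [a / n_samples for a in avg]


# Hypothetical two-class "network": logits are theta_c * x.
forward = lambda theta, x: [theta[0] * x, theta[1] * x]
print(mc_predictive(1.0, mu=[1.0, -1.0], sigma=[0.5, 0.5], forward=forward, n_samples=20))
```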
<h2 id="downstream-uncertainty-performance">Downstream uncertainty performance</h2>
<p>If you are like me, metrics such as negative-log-likelihood and expected calibration error do not mean much when it comes to judging whether your algorithm has ‘better uncertainty’.
We should also test on downstream tasks to see how reliable our methods are, and more and more papers are starting to do so (see also this year’s <a href="https://neurips.cc/Conferences/2021/Schedule?showEvent=21827">NeurIPS Bayesian Deep Learning workshop</a>, which makes this a priority). The VOGN paper tested on two downstream tasks: continual learning and out-of-distribution performance.</p>
<p><strong>Continual Learning</strong>: I personally think continual learning is a very good way to test approximate Bayesian inference algorithms, particularly variational deep-learning algorithms. We tested VOGN on Permuted MNIST, finding it performs as well as VCL (<a href="https://arxiv.org/pdf/1710.10628.pdf">Nguyen et al., 2018</a>; <a href="https://arxiv.org/pdf/1905.02099.pdf">Swaroop et al., 2019</a>), but trained more than an order of magnitude quicker. More recently, VOGN has also achieved good results on a bigger Split CIFAR benchmark (see Section 4.5 of <a href="https://team-approx-bayes.github.io/assets/rgroups/thesis_runa_eschenhagen.pdf">Eschenhagen (2019)</a>), which VCL struggles to scale to.</p>
<p><strong>Out-of-distribution performance</strong>: We also tested VOGN on standard out-of-distribution benchmarks, such as training on CIFAR-10 and testing on SVHN and LSUN. Figure 5 in the paper shows results (histograms of predictive entropy), where we qualitatively see VOGN performing well.</p>
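<p>The histograms in that figure are of predictive entropy. A minimal sketch of the quantity, computed from class probabilities (for example, MC-averaged ones); higher entropy means a more uncertain prediction, and the maximum for $C$ classes is $\log C$:</p>

```python
import math


def predictive_entropy(probs):
    """H[p] = -sum_c p_c log p_c, in nats; terms with p_c = 0 contribute nothing."""
    return -sum(p * math.log(p) for p in probs if p > 0)


print(predictive_entropy([0.98, 0.01, 0.01]))  # confident prediction: low entropy
print(predictive_entropy([1 / 3, 1 / 3, 1 / 3]))  # maximally uncertain: log(3)
```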
<h2 id="conclusions-and-further-reading">Conclusions and further reading</h2>
<p>In the first blog post, we derived VOGN, our natural-gradient variational inference algorithm. In this blog post, we scaled it all the way to ImageNet. We made approximations along the way, but by being careful about when and where to make them, we have ended up with a practical algorithm that retains Bayesian principles. Our final algorithm also performs reasonably well on downstream tasks.</p>
<p>It has been two years since we published VOGN’s performance on ImageNet, and the field continues to move at breakneck pace. More algorithms and benchmarks continue to be published, as well as more insight into VI.</p>
<ul>
<li>If you are interested in the maths of improving natural-gradient variational inference algorithms, I particularly recommend work by Wu Lin and co-authors. They looked at improving VON (same as VOGN but without the Gauss-Newton matrix approximation), deriving another quick NGVI algorithm that can perform well (<a href="https://arxiv.org/pdf/2002.10060.pdf">Lin et al., 2020</a>).
They have also expanded to mixtures of exponential family distributions (<a href="https://arxiv.org/pdf/1906.02914.pdf">Lin et al., 2019</a>), and looked at structured natural gradient descent, drawing links to Newton-like algorithms (<a href="https://arxiv.org/pdf/2102.07405.pdf">Lin et al., 2021</a>).</li>
<li>There is also interesting work looking at pathologies of mean-field VI on neural networks (VOGN is a mean-field VI algorithm). There are problems in the single-hidden-layer setting (<a href="https://arxiv.org/pdf/1909.00719.pdf">Foong et al., 2020</a>), and problems in the wide limit (<a href="https://arxiv.org/pdf/2106.07052.pdf">Coker et al., 2021</a>).</li>
</ul>
<h2 id="acknowledgements">Acknowledgements</h2>
<p>Firstly, many thanks to my co-authors Kazuki Osawa, Anirudh Jain, Runa Eschenhagen, Richard Turner, Rio Yokota and Emtiyaz Khan, many of whom also provided valuable feedback on these blog posts. I would also like to thank Andrew Foong, Wessel Bruinsma and Stratis Markou for their comments during drafting of these blog posts.</p>
<h2 id="post-scipt-a-guide-on-how-to-tune-vogn">Post-script: A guide on how to tune VOGN</h2>
<p>As we saw in <a href="#figure-VOGNalgorithm">Algorithm 1</a>, there are many hyperparameters that need tuning for VOGN (and generally for VI at large scale). Here we briefly summarise how we did this in <a href="https://arxiv.org/pdf/1906.02506.pdf">Osawa et al. (2019)</a>, following the guidelines presented in Section 3 of the paper. The key idea is to make sure VOGN training closely follows Adam’s trajectory in the beginning of training.</p>
<ol>
<li>
<p>First, we tune hyperparameters for OGN, which is the same as VOGN except setting $\vparam_t=\vmu_t$ (no MC sampling). OGN is more stable than VOGN and is a convenient stepping stone as we move from Adam to VOGN. So, we initialise OGN’s hyperparameters at Adam’s values, and tune until OGN is competitive with Adam. This requires tuning learning rates, prior precision $\delta$, and setting a suitable value for the data augmentation factor (if using data augmentation).</p>
</li>
<li>
<p>Then, we move to VOGN. Now, we (fine-)tune the prior precision $\delta$, warm-start the tempering parameter $\tau$ (for example, increasing $\tau$ as $0.1\rightarrow1$ during the first few optimisation steps), and choose the number of MC samples $K$ (more samples means more stable training, but higher computational cost). We also now consider adding an external damping factor $\gamma$ if required.</p>
</li>
</ol>
<p><em>Author: Siddharth Swaroop.</em> Having derived a natural-gradient variational inference algorithm, we now turn our attention to scaling it all the way to ImageNet. By borrowing tricks developed for Adam, we can get fast convergence, good performance, and reasonable uncertainties.</p>
<p><strong>Bayesian Deep Learning via Subnetwork Inference</strong> (MLG Blog; published 2021-07-21; <a href="https://mlg.eng.cam.ac.uk/blog/2021/07/21/subnetwork-inference">https://mlg.eng.cam.ac.uk/blog/2021/07/21/subnetwork-inference</a>)</p>
<h2 id="motivation-bayesian-deep-learning-is-important-but-hard">Motivation: Bayesian deep learning is important but hard</h2>
<p>Despite their successes, accurately quantifying uncertainty in the predictions of deep neural networks (DNNs) is notoriously hard, especially if the data distribution shifts between training and test time.
In practice, this often leads to overconfident predictions, which are particularly harmful in safety-critical applications such as healthcare and autonomous driving.
One principled approach to quantifying the predictive uncertainty of a neural net is to use <em>Bayesian inference</em>.</p>
<p>The standard practice in deep learning is to estimate the parameters using just a <em>single point</em> found through gradient-based optimisation. In contrast, in Bayesian deep learning (check out <a href="https://jorisbaan.nl/2021/03/02/introduction-to-bayesian-deep-learning.html">this blog post</a> for an introduction to Bayesian deep learning), the goal is to infer a <em>full posterior distribution</em> over the model’s weights. By capturing a distribution over weights, we capture a distribution over neural networks, which means that prediction essentially takes into account the predictions of (infinitely) many neural networks. Intuitively, on data points that are very distinct from the training data, these different neural nets will disagree on their predictions. This will result in high <em>predictive uncertainty</em> on such data points and therefore reduce overconfidence.</p>
<p>The problem is that modern deep neural nets are so big that even trying to approximate this posterior is highly non-trivial. We are not even talking about humongous 175-billion-parameter models like OpenAI’s GPT-3 here (<a href="https://arxiv.org/abs/2005.14165">Brown <em>et al.</em>, 2020</a>): even for a neural net with more than just a few layers, it’s hard to do good posterior inference! Therefore, it’s becoming more and more challenging to design approximate inference methods that actually scale.</p>
<p>To cope with this problem, many existing Bayesian deep learning methods make very strong and unrealistic approximations to the structure of the posterior. For example, the common <em>mean field approximation</em> approximates the posterior by a distribution which fully factorises over individual weights. Unfortunately, recent papers (<a href="https://arxiv.org/abs/1906.02530">Ovadia <em>et al.</em>, 2019</a>; <a href="https://arxiv.org/abs/1906.11537">Foong <em>et al.</em>, 2019</a>) have empirically demonstrated that such strong assumptions result in bad performance on downstream tasks such as uncertainty estimation. Can we do better than this?
$\newcommand{\vy}{\mathbf{y}}$
$\newcommand{\vw}{\mathbf{w}}$
$\newcommand{\mH}{\mathbf{H}}$
$\newcommand{\mX}{\mathbf{X}}$
$\newcommand{\D}{\mathcal{D}}$
$\newcommand{\c}{\textsf{c}}$</p>
<h2 id="idea-do-inference-over-only-a-small-subset-of-the-model-parameters">Idea: Do inference over only a small subset of the model parameters!</h2>
<p>Most existing Bayesian deep learning methods try to do inference over all the weights of the neural net. But do we actually need to estimate a posterior distribution over all weights?</p>
<p>It turns out that you often don’t need all those weights. In particular, recent research (<a href="https://arxiv.org/abs/1710.09282">Cheng <em>et al.</em>, 2017</a>) has shown that, since deep neural nets are so heavily overparametrised, it’s possible to find a small subnetwork within a neural net containing only a very small fraction of the weights, which, miraculously, can achieve the same accuracy as the full neural net. These subnetworks can be found by so-called pruning techniques.</p>
<div class="image-container">
<img src="/blog/assets/images/subnetwork-inference/pruning.png" alt="An illustration of neural network pruning (Han et al., 2015)." id="figure-pruning" style="width: 100%; max-width: 500px" />
<p class="caption">
Figure 1: An illustration of neural network pruning (<a href="https://arxiv.org/abs/1506.02626">Han <i>et al.</i>, 2015</a>).
</p>
</div>
<p>As shown in <a href="#figure-pruning">Figure 1</a>, pruning techniques typically first train the neural net, and then, after training, remove certain weights or even entire neurons according to some criterion. There has been a lot of recent interest in this research direction; for example, the best paper award at ICLR 2019 went to Jonathan Frankle and Michael Carbin’s now-famous work on the lottery ticket hypothesis (<a href="https://arxiv.org/abs/1803.03635">Frankle and Carbin, 2018</a>), which showed that you can even retrain the pruned network from its original initialisation and still achieve the same accuracy as the full network.</p>
<p>But how does this help us? We asked ourselves an analogous question about model uncertainty: can a full deep neural net’s model uncertainty be well preserved by a small subnetwork’s model uncertainty? It turns out that the answer is yes, and in the remainder of this blog post, you will learn how we came to this conclusion.</p>
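<p>Pruning criteria vary; one of the simplest is weight magnitude. The sketch below (a hypothetical pure-Python helper) keeps the given fraction of weights with the largest absolute values and returns a 0/1 mask. It illustrates the general idea rather than the specific method of any paper cited here.</p>

```python
def magnitude_prune_mask(weights, keep_fraction):
    """Return a 0/1 mask that keeps the largest-|w| fraction of the weights."""
    n_keep = max(1, int(round(keep_fraction * len(weights))))
    # Indices sorted by decreasing magnitude; the first n_keep survive pruning.
    order = sorted(range(len(weights)), key=lambda i: abs(weights[i]), reverse=True)
    keep = set(order[:n_keep])
    return [1 if i in keep else 0 for i in range(len(weights))]


w = [0.1, -2.0, 0.5, 0.03]
mask = magnitude_prune_mask(w, keep_fraction=0.5)
print(mask, [wi * mi for wi, mi in zip(w, mask)])
```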
<h2 id="our-proposed-approximation-to-the-posterior">Our proposed approximation to the posterior</h2>
<p>Assume that we have divided the weights $\vw$ into two disjoint subsets: (1) the subnetwork $\vw_S$ and (2) the set of all remaining weights $\{\vw_r\}_{r \in S^\c}$. We will later describe how we select the subnetwork; for now, just assume that we have it already. We propose to approximate the posterior distribution over the weights $\vw$ as follows:
\begin{equation}
p(\vw \cond \D)
\overset{\text{(i)}}{\approx}
p(\vw_S \cond \D)
\prod_{r \in S^\c} \delta(\vw_r - \widehat{\vw}_r)
\overset{\text{(ii)}}{\approx}
q(\vw_S)
\prod_{r \in S^\c} \delta(\vw_r - \widehat{\vw}_r)
=: q_S(\vw).
\end{equation}
The first step (i) of our posterior approximation then involves a posterior distribution over just the subnetwork $\vw_S$, and delta functions over all remaining weights $\{\vw_r\}_{r \in S^\c}$. Put differently, we only treat the subnetwork $\vw_S$ in a probabilistic way, and assume that each remaining weight $\vw_r$ is deterministic and set to some fixed value $\widehat\vw_r$. Unfortunately, exact inference over the subnetwork is still intractable, so, in the second step (ii) of our approximation, we introduce an approximate posterior $q$ over the subnetwork $\vw_S$. Importantly, as the subnetwork is much smaller than the full network, this allows us to use expressive posterior approximations that would otherwise be computationally intractable (<em>e.g.</em> full-covariance Gaussians). That’s it.</p>
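<p>Sampling from $q_S(\vw)$ makes the construction concrete: only the indices in $S$ are random, and every other weight is copied from its fixed value $\widehat\vw_r$. The minimal sketch below assumes $q(\vw_S)$ is a full-covariance Gaussian (the choice discussed under Q1), samples it via a Cholesky factor, and uses hypothetical helper names.</p>

```python
import math
import random


def cholesky(A):
    """Lower-triangular L with A = L L^T, for a small symmetric positive-definite A."""
    n = len(A)
    L = [[0.0] * n for _ in range(n)]
    for i in range(n):
        for j in range(i + 1):
            s = sum(L[i][k] * L[j][k] for k in range(j))
            if i == j:
                L[i][j] = math.sqrt(A[i][i] - s)
            else:
                L[i][j] = (A[i][j] - s) / L[j][j]
    return L


def sample_subnetwork_posterior(w_hat, subnet_idx, mean_S, cov_S, rng):
    """Draw w ~ q_S(w): subnetwork weights from N(mean_S, cov_S);
    all other weights stay at their fixed values w_hat[r] (the delta functions)."""
    L = cholesky(cov_S)
    z = [rng.gauss(0.0, 1.0) for _ in subnet_idx]
    w = list(w_hat)  # copy: weights outside S are deterministic
    for a, i in enumerate(subnet_idx):
        w[i] = mean_S[a] + sum(L[a][b] * z[b] for b in range(a + 1))
    return w


w_hat = [0.5, -1.0, 2.0, 0.3]  # made-up point estimate of four weights
print(sample_subnetwork_posterior(w_hat, [1, 3], [-1.0, 0.3],
                                  [[0.04, 0.01], [0.01, 0.09]], random.Random(0)))
```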
<p>There are a few questions that we still need to answer:</p>
<ol class="custom questions">
<li>How do we choose and infer the subnetwork posterior $q(\vw_S)$? That is, what form does $q$ have, and how do we infer its parameters?</li>
<li>How do we set the fixed values $\widehat\vw_r$ of all remaining weights $\{\vw_r\}_{r \in S^\c}$?</li>
<li>How do we select the subnetwork $\vw_S$ in the first place?</li>
<li>How do we make predictions with this approximate posterior?</li>
<li>How does subnetwork inference perform in practice?</li>
</ol>
<p>Let’s start with Q1.</p>
<h2 id="q1-how-do-we-choose-and-infer-the-subnetwork-posterior-">Q1. How do we choose and infer the subnetwork posterior $q(\vw_S)$?</h2>
<p>In this work, we infer a full-covariance Gaussian posterior over the subnetwork using the Laplace approximation, which is a classic approximate inference technique. If you don’t recall how the Laplace approximation works, below we provide a short summary. For more details on the Laplace approximation and a review of its use in deep learning, please refer to <a href="https://arxiv.org/abs/2106.14806">Daxberger et al. (2021)</a>.</p>
<p>The Laplace approximation proceeds in two steps.</p>
<ol>
<li>
<p>Obtain a point estimate of all model weights using maximum a posteriori (MAP) inference. In deep learning, this is typically done using stochastic gradient-based optimisation methods such as SGD.
\begin{equation}
\widehat\vw = \argmax_{\vw} \, [\log p(\vy \cond \mX, \vw) + \log p(\vw)]
\end{equation}</p>
</li>
<li>
<p>Locally approximate the log-density of the posterior around $\widehat\vw$ with a second-order Taylor expansion. This produces a full-covariance Gaussian posterior over the weights, where the mean of the Gaussian is simply the MAP estimate, and the covariance matrix of the Gaussian is the inverse of the Hessian $\mH$ of the loss (the negative log joint density $-\log p(\vy \cond \mX, \vw) - \log p(\vw)$, summed over the training data points) with respect to the weights $\vw$, evaluated at $\widehat\vw$:
\begin{equation}
p(\vw \cond \D) \approx q(\vw) = \Normal(\vw \cond \widehat\vw, \mH^{-1}).
\end{equation}</p>
</li>
</ol>
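<p>The two steps can be illustrated on a model small enough to handle exactly. The sketch below is a hypothetical example (logistic regression with a Gaussian prior $\Normal(\mathbf{0}, \lambda^{-1}\mathbf{I})$, not a neural network): step 1 finds the MAP estimate with Newton's method, which works here because the negative log joint is convex, and step 2 inverts the Hessian of the negative log joint, summed over the data points, at that estimate. In a deep network, step 1 would instead use SGD as described above.</p>

```python
import numpy as np


def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))


def neg_log_joint_grad_hess(w, X, y, prior_precision):
    """Gradient and Hessian of -log p(y | X, w) - log p(w), summed over data points."""
    p = sigmoid(X @ w)
    grad = X.T @ (p - y) + prior_precision * w
    hess = X.T @ (X * (p * (1.0 - p))[:, None]) + prior_precision * np.eye(X.shape[1])
    return grad, hess


def laplace_approximation(X, y, prior_precision=1.0, n_newton_steps=50):
    # Step 1: MAP estimate (Newton's method; fine here because the objective is convex).
    w = np.zeros(X.shape[1])
    for _ in range(n_newton_steps):
        grad, hess = neg_log_joint_grad_hess(w, X, y, prior_precision)
        w = w - np.linalg.solve(hess, grad)
    # Step 2: Gaussian centred at the MAP, covariance = inverse Hessian at the MAP.
    _, hess = neg_log_joint_grad_hess(w, X, y, prior_precision)
    return w, np.linalg.inv(hess)


# Toy data drawn from a logistic model with made-up weights [1.5, -2.0].
rng = np.random.default_rng(0)
X = rng.normal(size=(200, 2))
y = (rng.random(200) < sigmoid(X @ np.array([1.5, -2.0]))).astype(float)
w_map, cov = laplace_approximation(X, y)
```

The returned pair `(w_map, cov)` parameterises $q(\vw) = \Normal(\vw \cond \widehat\vw, \mH^{-1})$; for a network with millions of weights, `cov` is exactly the matrix that becomes too large to store.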
<p>In effect, this defines a Gaussian centered at the MAP estimate, with a covariance matrix that matches the curvature of the loss at the MAP estimate, as illustrated in <a href="#figure-laplace">Figure 2</a>.</p>
<div class="image-container">
<img src="/blog/assets/images/subnetwork-inference/laplace.png" alt="A conceptual illustration of the Laplace approximation in one dimension (image adapted with kind permission from Richard Turner). We plot the parameter $\mathbf{w}$ ($x$-axis) against the density of the true posterior $p(\mathbf{w}\cond\mathcal{D})$ (in black) as well as that of the corresponding Laplace approximation $q(\mathbf{w})$ (in red). As we can see, $q(\mathbf{w})$ is a Gaussian centered at the mode $\widehat{\mathbf{w}}$ of the posterior $p(\mathbf{w}\cond\mathcal{D})$, with covariance matrix matching the curvature of $p(\mathbf{w}\cond\mathcal{D})$ at $\widehat{\mathbf{w}}$." id="figure-laplace" style="width: 100%; max-width: 400px" />
<p class="caption">
Figure 2: A conceptual illustration of the Laplace approximation in one dimension (image adapted with kind permission from Richard Turner). We plot the parameter $\mathbf{w}$ ($x$-axis) against the density of the true posterior $p(\mathbf{w}\cond\mathcal{D})$ (in black) as well as that of the corresponding Laplace approximation $q(\mathbf{w})$ (in red). As we can see, $q(\mathbf{w})$ is a Gaussian centered at the mode $\widehat{\mathbf{w}}$ of the posterior $p(\mathbf{w}\cond\mathcal{D})$, with covariance matrix matching the curvature of $p(\mathbf{w}\cond\mathcal{D})$ at $\widehat{\mathbf{w}}$.
</p>
</div>
<p>The main advantage of the Laplace approximation, and also the reason why we use it, is that it is applied <em>post-hoc</em> on top of a MAP estimate and doesn’t require us to re-train the network. This is practically very appealing as MAP estimation is something we can do very well in deep neural nets. The main issue, however, is that it requires us to compute, store, and invert the full Hessian $\mH$ over all weights. This scales quadratically in space and cubically in time (in terms of the number of weights) and is therefore computationally intractable for modern neural nets.</p>
<p>Fortunately, in our case, we don’t actually want to do inference over <em>all</em> the weights, but only over a subnetwork. In this case, the second step of the Laplace approximation involves inferring a full-covariance Gaussian posterior over only the subnetwork weights $\vw_S$:
\begin{equation}
p(\vw_S \cond \D) \approx q(\vw_S) = \Normal(\vw_S \cond \widehat\vw_S, \mH_S^{-1}).
\end{equation}
Here, $\mH_S$ is the Hessian of the loss with respect to the subnetwork weights $\vw_S$ only. This is now tractable, since the subnetwork will in practice be substantially smaller than the full network, effectively giving us quadratic gains in space complexity and cubic gains in time complexity!</p>
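<p>A detail worth spelling out: because every weight outside the subnetwork is held at its MAP value, $\mH_S$ is simply the $S \times S$ block of the full Hessian, so only a $K \times K$ matrix (with $K = |S|$) has to be inverted. Its inverse is the covariance of $\vw_S$ <em>conditional</em> on the other weights, which in general differs from the corresponding block of $\mH^{-1}$. The sketch below, with a hypothetical function name, builds a full toy Hessian only to check this; in practice one would compute the $S \times S$ block directly, since forming $\mH$ is exactly what we want to avoid.</p>

```python
import numpy as np


def subnetwork_laplace_covariance(H_full, subnet_idx):
    """Covariance H_S^{-1} of q(w_S), where H_S is the (S, S) block of the Hessian."""
    H_S = H_full[np.ix_(subnet_idx, subnet_idx)]   # K x K block, K = len(subnet_idx)
    return np.linalg.inv(H_S)                       # O(K^3) work instead of O(D^3)


# Toy check with a random positive-definite "Hessian" over D = 6 weights.
rng = np.random.default_rng(1)
A = rng.normal(size=(6, 6))
H = A @ A.T + 6.0 * np.eye(6)
S = np.array([0, 3])
R = np.array([1, 2, 4, 5])                          # the remaining weights S^c
cov_S = subnetwork_laplace_covariance(H, S)

# The block inverse equals the conditional covariance of w_S given w_R under N(w_map, H^{-1}).
Sigma = np.linalg.inv(H)
conditional = Sigma[np.ix_(S, S)] - Sigma[np.ix_(S, R)] @ np.linalg.solve(
    Sigma[np.ix_(R, R)], Sigma[np.ix_(R, S)])
print(np.allclose(cov_S, conditional))              # True
```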
<h2 id="q2-how-do-we-set-the-fixed-values-widehatmathbfw_r-of-all-remaining-weights-mathbfw_r_r-in-sc">Q2. How do we set the fixed values $\widehat{\mathbf{w}}_r$ of all remaining weights $\{\mathbf{w}_r\}_{r \in S^\c}$?</h2>
<p>This also answers Q2, how to set the weights outside the subnetwork: since the Laplace approximation requires us to first obtain a MAP estimate over all weights, it’s natural to simply leave all other weights at their MAP estimates!</p>
<p>Let’s now look at how subnetwork inference is done in practice.</p>
<h2 id="the-full-subnetwork-inference-algorithm">The full subnetwork inference algorithm</h2>
<p>Overall, our proposed subnetwork inference algorithm comprises the following four steps:</p>
<ol>
<li>Obtain a MAP estimate over all the weights of the neural net using standard optimisation methods such as SGD (see <a href="#figure-map">Figure 3</a>).</li>
</ol>
<div class="image-container">
<img src="/blog/assets/images/subnetwork-inference/a_map.png" alt="Step 1: Point estimation." id="figure-map" style="width: 100%; max-width: 300px" />
<p class="caption">
Figure 3: Step 1: Point estimation.
</p>
</div>
<ol start="2">
<li>Select a small subnetwork (see <a href="#figure-subnet">Figure 4</a>) — we’ll discuss in a second how this can be done in practice.</li>
</ol>
<div class="image-container">
<img src="/blog/assets/images/subnetwork-inference/b_subnet.png" alt="Step 2: Subnetwork selection." id="figure-subnet" style="width: 100%; max-width: 300px" />
<p class="caption">
Figure 4: Step 2: Subnetwork selection.
</p>
</div>
<ol start="3">
<li>Perform Bayesian inference just over the subnetwork (see <a href="#figure-inference">Figure 5</a>). As described above, we use the Laplace approximation to infer a full-covariance Gaussian over the subnetwork, and leave all other weights at their MAP estimates.</li>
</ol>
<div class="image-container">
<img src="/blog/assets/images/subnetwork-inference/c_inference.png" alt="Step 3: Bayesian inference." id="figure-inference" style="width: 100%; max-width: 300px" />
<p class="caption">
Figure 5: Step 3: Bayesian inference.
</p>
</div>
<ol start="4">
<li>Lastly, use the resulting mixed probabilistic–deterministic model to make predictions (see <a href="#figure-prediction">Figure 6</a>).</li>
</ol>
<div class="image-container">
<img src="/blog/assets/images/subnetwork-inference/d_prediction.png" alt="Step 4: Prediction." id="figure-prediction" style="width: 100%; max-width: 300px" />
<p class="caption">
Figure 6: Step 4: Prediction.
</p>
</div>
<p>Ok, now we know how to do inference over the subnetwork, but how do we find the subnetwork in the first place?</p>
<h2 id="q3-how-do-we-select-the-subnetwork-mathbfw_s-in-the-first-place">Q3. How do we select the subnetwork $\mathbf{w}_S$ in the first place?</h2>
<p>Recall that we want to preserve as much model uncertainty as possible with our subnetwork. A natural goal is therefore to find the subnetwork whose posterior is <em>closest</em> to the full network posterior. That is, we want to find the subset of weights that minimises some measure of discrepancy between the posterior over the full network and the posterior over the subnetwork.</p>
<p>To measure this discrepancy, we choose to use the Wasserstein distance:
\begin{align}
&\min \text{Wass}[\ \text{exact full posterior}\ |\ \text{subnetwork posterior}\ ] \nonumber \vphantom{\prod} \newline
&\qquad= \min \text{Wass}[\ p(\mathbf{w} \cond \mathcal{D})\ |\ q_S(\mathbf{w})\ ] \vphantom{\prod} \newline
&\qquad\approx \min \text{Wass}[\ \mathcal{N}\left(\mathbf{w}; \widehat{\mathbf{w}}, \mathbf{H}^{-1}\right)\ |\ \mathcal{N}(\mathbf{w}_S; \widehat{\mathbf{w}}_S, \mathbf{H}_S^{-1}) \prod_{r \in S^\c} \delta(\mathbf{w}_r - \widehat{\mathbf{w}}_r )\ ].
\end{align}
As the exact full network posterior $p(\mathbf{w} \cond \mathcal{D})$ is intractable, we approximate it here as a Gaussian $\mathcal{N}\left(\mathbf{w}; \widehat{\mathbf{w}}, \mathbf{H}^{-1}\right)$ over all weights (also estimated via the Laplace approximation). Also, as described earlier, the subnetwork posterior $q_S(\mathbf{w})$ is composed of a Gaussian $\mathcal{N}(\mathbf{w}_S; \widehat{\mathbf{w}}_S, \mathbf{H}_S^{-1})$ over the subnetwork and delta functions $\delta(\mathbf{w}_r - \widehat{\mathbf{w}}_r )$ over all other weights $\{\mathbf{w}_r\}_{r \in S^\c}$. Note that, due to the delta functions, the subnetwork posterior is degenerate; this is why we use the Wasserstein distance, which remains well-defined for such degenerate distributions (unlike, for example, the KL divergence, which is infinite in either direction here).</p>
<p>Unfortunately, this objective is still intractable, as it depends on all entries of the Hessian of the full network. To obtain a tractable objective, we assume that the full network posterior is factorised. By making this factorisation assumption, the Wasserstein objective now only depends on the diagonal entries of the Hessian, which are cheap to compute. I know what you’re thinking right now: “Didn’t they just tell us that the whole point of this subnetwork inference thing is to avoid making the assumption that the posterior is diagonal? And now they’re telling us that, actually, we still do have to make this assumption? This doesn’t make any sense!”</p>
<p>Well, in fact, it turns out that making the diagonal assumption <em>just for the purpose of subnetwork selection</em>, and then doing <em>full-covariance</em> Gaussian inference over the subnetwork, works much better than making the diagonal assumption for inference itself, even when that diagonal inference is done over <em>all</em> weights. We’ll see this in the experiments later.</p>
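<p>To see why the factorised assumption makes the objective cheap, here is a short derivation, assuming the squared 2-Wasserstein distance $\text{Wass}_2^2$. Suppose the full posterior is $\mathcal{N}(\widehat{\mathbf{w}}, \operatorname{diag}(\sigma_1^2, \dots, \sigma_D^2))$, where for a diagonal Laplace approximation $\sigma_d^2 = 1/H_{dd}$. Under the same assumption, the subnetwork posterior keeps variance $\sigma_d^2$ for $d \in S$ and has zero variance (a delta function) for $d \in S^\c$. Writing $m_d = 1$ if $d \in S$ and $m_d = 0$ otherwise, both distributions share the mean $\widehat{\mathbf{w}}$, and for Gaussians with diagonal covariances the squared 2-Wasserstein distance decomposes over dimensions (using $m_d \in \{0, 1\}$ in the second line):
\begin{align}
\text{Wass}_2^2
&= \|\widehat{\mathbf{w}} - \widehat{\mathbf{w}}\|_2^2 + \sum_{d=1}^{D} \left(\sigma_d - m_d \sigma_d\right)^2 \nonumber \newline
&= \sum_{d=1}^{D} \sigma_d^2 \,(1 - m_d) = \sum_{d \in S^\c} \sigma_d^2. \nonumber
\end{align}
Only the diagonal entries $H_{dd}$ enter, and for a fixed subnetwork size $|S| = K$ this sum of discarded variances is smallest when $S$ holds the $K$ weights with the largest marginal variances.</p>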
<p>All in all, our proposed subnetwork selection procedure is as follows:</p>
<ol>
<li>Estimate a factorised Gaussian posterior over all weights, using for example a diagonal Laplace approximation.</li>
<li>Select the weights with the largest marginal variances. Why these? Under the diagonal assumption, one can show that, for a fixed subnetwork size, the weights with the largest marginal variances are exactly those that minimise the Wasserstein objective defined above.</li>
</ol>
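<p>The selection step then reduces to a top-$K$ search over marginal variances. Below is a minimal NumPy sketch, assuming the marginal variances (for example $1/H_{dd}$ from a diagonal Laplace approximation) are already available as a flat array; the function names are hypothetical. The objective function assumes the squared 2-Wasserstein distance, under which, with a factorised posterior, each weight left out of the subnetwork contributes its marginal variance.</p>

```python
import numpy as np


def select_subnetwork(marginal_variances, k):
    """Indices of the k weights with the largest marginal variances."""
    return np.argsort(marginal_variances)[::-1][:k]


def factorised_wasserstein_objective(marginal_variances, subnet_idx):
    """Squared 2-Wasserstein objective under the factorised assumption: the total
    variance of the weights outside the subnetwork (collapsed to point masses)."""
    outside = np.ones(marginal_variances.shape[0], dtype=bool)
    outside[subnet_idx] = False
    return marginal_variances[outside].sum()


# Hypothetical diagonal Hessian entries -> diagonal-Laplace marginal variances 1 / H_dd.
rng = np.random.default_rng(0)
hessian_diag = rng.uniform(0.5, 5.0, size=8)
marginal_variances = 1.0 / hessian_diag
subnet = select_subnetwork(marginal_variances, k=3)
print(subnet, factorised_wasserstein_objective(marginal_variances, subnet))
```

For a large network this is a single sort over the variance vector, which is why selection is cheap compared with the full-covariance inference that follows.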
<h2 id="q4-how-do-we-make-predictions-with-this-approximate-posterior">Q4. How do we make predictions with this approximate posterior?</h2>
<p>Great, we now know that a subnetwork can be found by (approximately) minimising the Wasserstein distance between the subnetwork posterior and the full network posterior. But how do we make predictions with this weird approximate posterior that is partly probabilistic and partly deterministic? We simply use all the weights of the neural net to make predictions: we integrate out the weights in the subnetwork, and leave all other weights fixed at their MAP estimates. To integrate out the subnetwork weights, one can use either Monte Carlo or a closed-form approximation; please refer to the full paper for more details (the reference is given at the end of this blog post). Subnetwork inference therefore aims to combine the strong predictive accuracy of the MAP estimate with the calibrated uncertainties of a Bayesian posterior.</p>
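<p>The Monte Carlo option can be sketched in a few lines. The example below is hypothetical: a tiny one-hidden-layer tanh regression network whose weights live in a flat vector, with a Gaussian over the subnetwork entries and all other entries fixed at the MAP. It returns the mean and standard deviation of the network outputs across samples; for a Gaussian regression likelihood one would typically add the observation noise variance to obtain the full predictive variance, and for classification one would average predicted class probabilities instead.</p>

```python
import numpy as np


def mlp_forward(w, x, n_hidden):
    """Hypothetical 1-hidden-layer tanh network; weights packed in one flat vector:
    w[:H] input->hidden, w[H:2H] hidden biases, w[2H:3H] hidden->output, w[3H] output bias."""
    W1 = w[:n_hidden].reshape(1, n_hidden)
    b1 = w[n_hidden:2 * n_hidden]
    W2 = w[2 * n_hidden:3 * n_hidden].reshape(n_hidden, 1)
    b2 = w[3 * n_hidden]
    return (np.tanh(x[:, None] @ W1 + b1) @ W2).ravel() + b2


def mc_predictive(w_map, subnet_idx, cov_S, x, n_hidden, n_samples=200, seed=0):
    """Monte Carlo estimate of the predictive mean and std of the network output."""
    rng = np.random.default_rng(seed)
    preds = []
    for _ in range(n_samples):
        w = w_map.copy()                           # weights outside S stay at the MAP
        w[subnet_idx] = rng.multivariate_normal(w_map[subnet_idx], cov_S)
        preds.append(mlp_forward(w, x, n_hidden))
    preds = np.stack(preds)                        # shape (n_samples, n_points)
    return preds.mean(axis=0), preds.std(axis=0)


n_hidden = 4
rng = np.random.default_rng(3)
w_map = rng.normal(size=3 * n_hidden + 1)          # stand-in for trained MAP weights
subnet_idx = np.array([0, 9])                      # hypothetical subnetwork
cov_S = 0.1 * np.eye(2)                            # stand-in for H_S^{-1}
x = np.linspace(-2.0, 2.0, 5)
mean, std = mc_predictive(w_map, subnet_idx, cov_S, x, n_hidden)
```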
<p>Finally, we demonstrate the effectiveness of subnetwork inference in two experiments.</p>
<h2 id="q5-how-does-subnetwork-inference-perform-in-practice">Q5. How does subnetwork inference perform in practice?</h2>
<h3 id="experiment-1-how-does-subnetwork-inference-preserve-predictive-uncertainty">Experiment 1: How does subnetwork inference preserve predictive uncertainty?</h3>
<p>In the first experiment we train a small, 2-layer, fully-connected network with a total of 2600 weights on a 1D regression task, shown in <a href="#figure-regression">Figure 7</a>.</p>
<div class="image-container">
<img src="/blog/assets/images/subnetwork-inference/regression.png" alt="Predictive distributions (mean $\pm$ std) for 1D regression. The numbers in parentheses denote the number of parameters over which inference was done (out of 2600 in total). Wasserstein subnetwork inference maintains richer predictive uncertainties at smaller parameter counts." id="figure-regression" style="width: 100%; max-width: 500px" />
<p class="caption">
Figure 7: Predictive distributions (mean $\pm$ std) for 1D regression. The numbers in parentheses denote the number of parameters over which inference was done (out of 2600 in total). Wasserstein subnetwork inference maintains richer predictive uncertainties at smaller parameter counts.
</p>
</div>
<p>The number in parentheses in each plot title denotes the number of weights over which we do inference; for example, for the MAP estimate (<a href="#figure-regression">Figure 7</a>, top left), inference was done over zero weights. As you can see, the 1D dataset we’re trying to fit consists of two separated clusters, and the goal here is to capture as much of the predictive uncertainty as possible, especially in between those data clusters (<a href="https://arxiv.org/abs/1906.11537">Foong <em>et al.</em>, 2019</a>). As expected, the point estimate (<a href="#figure-regression">Figure 7</a>, top left) doesn’t capture any uncertainty, but instead makes confident predictions even in regions where there’s no data, which is undesirable.</p>
<p>At the other extreme, we can infer a full-covariance Gaussian posterior over all the 2600 weights using a Laplace approximation (<a href="#figure-regression">Figure 7</a>, top middle), which is only tractable here due to the small model size. As we can see, a full-covariance Gaussian posterior is able to capture predictive uncertainty both at the boundaries and in between the data clusters, so we will treat it as the ideal reference (“ground-truth”) posterior for this experiment.</p>
<p>Of course, in larger-scale settings, a full-covariance Gaussian would be intractable, so people often resort to diagonal approximations which assume full independence between the weights (<a href="#figure-regression">Figure 7</a>, top right). Unfortunately, as we can see, even though we do inference over all 2600 weights, due to the diagonal assumption we sacrifice a lot of the predictive uncertainty, especially in-between the two data clusters, where it’s only marginally better than the point estimate.</p>
<p>Now what about our proposed subnetwork inference method? First, let’s try doing full-covariance Gaussian inference over only 50% (that is, 1300) of the weights, found using the described Wasserstein minimisation approach (<a href="#figure-regression">Figure 7</a>, bottom left). As we can see, this approach captures predictive uncertainty much better than the diagonal posterior, and is even quite close to the full-covariance Gaussian over all weights. Well, but 50% is still quite a lot of weights, so let’s try to go even smaller, much smaller, to only 3% of the weights, which corresponds to just 78 weights here (<a href="#figure-regression">Figure 7</a>, bottom middle). Even then, we’re still much better off than the diagonal Gaussian. Let’s push this to the extreme, and estimate a full-covariance Gaussian over as little as 1% (that is, 26) of the weights (<a href="#figure-regression">Figure 7</a>, bottom right). Perhaps surprisingly, even with 1% of weights remaining, we do significantly better than the diagonal baseline, and are able to capture significant in-between uncertainty!</p>
<p>Overall, the take-away from this experiment is that doing expressive inference over a very small, but carefully chosen subnetwork, and capturing weight correlations just within that subnetwork can preserve more predictive uncertainty than a method that does inference over all the weights, but ignores weight correlations.</p>
<h3 id="experiment-2-how-robust-is-subnetwork-inference-to-distribution-shift">Experiment 2: How robust is subnetwork inference to distribution shift?</h3>
<p>Ok, 1D regression is fun, but we’re of course interested in scaling this to more realistic settings. In this second experiment, we consider the task of image classification under distribution shift. This task is much more challenging than 1D regression, so the model that we use is significantly larger than before: we use a ResNet-18 model with over 11 million weights and, to remain tractable, do inference over as few as 42 thousand of them (only around 0.38%), again selected using Wasserstein minimisation.</p>
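<p>A back-of-the-envelope calculation shows why restricting inference to the subnetwork matters at this scale. The sketch assumes a dense covariance matrix stored in 32-bit floats and takes 11 million as the total weight count (the text says over 11 million); it illustrates the quadratic memory cost and does not describe the actual implementation.</p>

```python
def dense_covariance_gib(n_weights, bytes_per_entry=4):
    """Memory in GiB for a dense n_weights x n_weights covariance (float32 by default)."""
    return n_weights ** 2 * bytes_per_entry / 2 ** 30


subnet_size, total_size = 42_000, 11_000_000   # figures quoted in the text
print(f"fraction of weights in the subnetwork: {subnet_size / total_size:.2%}")
print(f"subnetwork covariance: {dense_covariance_gib(subnet_size):.1f} GiB")
print(f"full covariance:       {dense_covariance_gib(total_size):,.0f} GiB")
```

Under these assumptions the subnetwork covariance needs about 6.6 GiB, whereas a dense covariance over all weights would need several hundred thousand GiB.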
<p>We consider six baselines: the MAP estimate, a diagonal Laplace approximation over all 11M weights, Monte Carlo (MC) dropout over all weights (<a href="https://arxiv.org/abs/1506.02142">Gal and Ghahramani, 2015</a>), Variational Online Gauss-Newton (VOGN, <a href="https://arxiv.org/abs/1906.02506">Osawa <em>et al.</em>, 2019</a>), which estimates a factorised Gaussian over all weights, a Deep Ensemble (<a href="https://arxiv.org/abs/1612.01474">Lakshminarayanan <em>et al.</em>, 2017</a>) of 5 independently trained ResNet-18 models, and Stochastic Weight Averaging Gaussian (SWAG, <a href="https://arxiv.org/abs/1902.02476">Maddox <em>et al.</em>, 2019</a>), which estimates a low-rank plus diagonal posterior over all weights. As another baseline, we also consider subnetwork inference with a <em>randomly selected subnetwork</em> (denoted <em>Ours (Rand)</em>), which allows us to assess the impact of how the subnetwork is chosen.</p>
<div class="image-container">
<img src="/blog/assets/images/subnetwork-inference/benchmarks.png" alt="Example images from the (top) rotated MNIST and (bottom) corrupted CIFAR-10 benchmarks. (Top) An image of the digit 2 is increasingly rotated. (Bottom) An image of a dog is increasingly blurred." id="figure-benchmarks" style="width: 100%; max-width: 450px" />
<p class="caption">
Figure 8: Example images from the (top) rotated MNIST and (bottom) corrupted CIFAR-10 benchmarks. (Top) An image of the digit 2 is increasingly rotated. (Bottom) An image of a dog is increasingly blurred.
</p>
</div>
<p>We consider two benchmarks for evaluating robustness to distribution shift, both recently proposed by <a href="https://arxiv.org/abs/1906.02530">Ovadia <em>et al.</em> (2019)</a> (<a href="#figure-benchmarks">Figure 8</a>). The first is rotated MNIST, where the model is trained on the standard MNIST training set and then evaluated at test time on increasingly rotated MNIST digits (as shown for the digit 2 in <a href="#figure-benchmarks">Figure 8</a>, top). The second is corrupted CIFAR-10, where we again train on the standard CIFAR-10 training set but evaluate on corrupted CIFAR-10 images; the test set contains over a dozen different corruption types, each with five levels of increasing severity (in <a href="#figure-benchmarks">Figure 8</a>, bottom, the image of a dog becomes more and more blurred from left to right).</p>
<div class="image-container">
<img src="/blog/assets/images/subnetwork-inference/mnist.png" alt="Results on the rotated MNIST benchmark, showing the mean $\pm$ std of the test error (top) and log-likelihood (bottom) across three different seeds. Subnetwork inference achieves better uncertainty calibration and robustness to distribution shift than point-estimated networks and other Bayesian deep learning approaches (except for VOGN), while retaining accuracy." id="figure-mnist" style="width: 100%; max-width: 450px" />
<p class="caption">
Figure 9: Results on the rotated MNIST benchmark, showing the mean $\pm$ std of the test error (top) and log-likelihood (bottom) across three different seeds. Subnetwork inference achieves better uncertainty calibration and robustness to distribution shift than point-estimated networks and other Bayesian deep learning approaches (except for VOGN), while retaining accuracy.
</p>
</div>
<p>Let’s start with rotated MNIST (<a href="#figure-mnist">Figure 9</a>). On the x-axis, we have the degree of rotation, and on the y-axis, we plot two different metrics: on top, the test errors achieved by the different methods (where lower values are better), and on the bottom, the corresponding log-likelihood as a measure of calibration (where higher values are better). Here, we see that MAP, diagonal Laplace, MC dropout, the deep ensemble, SWAG, and the random subnetwork baseline all perform roughly similarly in terms of calibration (<a href="#figure-mnist">Figure 9</a>, bottom): their calibration worsens as the degree of rotation increases. In contrast, subnetwork inference (shown in dark blue) remains much better calibrated, even at high degrees of rotation. The only competitive method here is VOGN, which slightly outperforms subnetwork inference in terms of calibration. Importantly, observe that this increase in robustness does <em>not</em> come at the cost of accuracy (<a href="#figure-mnist">Figure 9</a>, top): Wasserstein subnetwork inference (as well as VOGN) retains the same accuracy as the other methods.</p>
<div class="image-container">
<img src="/blog/assets/images/subnetwork-inference/cifar10.png" alt="Results on the corrupted CIFAR-10 benchmark, showing the mean $\pm$ std of the test error (top) and log-likelihood (bottom) across three different seeds. Subnetwork inference achieves better uncertainty calibration and robustness to distribution shift than point-estimated networks and other Bayesian deep learning approaches, while retaining accuracy." id="figure-cifar10" style="width: 100%; max-width: 450px" />
<p class="caption">
Figure 10: Results on the corrupted CIFAR-10 benchmark, showing the mean $\pm$ std of the test error (top) and log-likelihood (bottom) across three different seeds. Subnetwork inference achieves better uncertainty calibration and robustness to distribution shift than point-estimated networks and other Bayesian deep learning approaches, while retaining accuracy.
</p>
</div>
<p>Now let’s look at corrupted CIFAR-10 (<a href="#figure-cifar10">Figure 10</a>). The story is similar: we plot the corruption severity on the x-axis against the error (<a href="#figure-cifar10">Figure 10</a>, top) and log-likelihood (<a href="#figure-cifar10">Figure 10</a>, bottom) on the y-axis. Here, MAP, diagonal Laplace, MC dropout and the random subnetwork baseline are all poorly calibrated (<a href="#figure-cifar10">Figure 10</a>, bottom). VOGN, SWAG and deep ensembles are a bit better calibrated, but are still significantly outperformed by subnetwork inference (again in dark blue), even at high corruption severities. Importantly, the improved robustness of Wasserstein subnetwork inference again does <em>not</em> compromise accuracy (<a href="#figure-cifar10">Figure 10</a>, top). In contrast, the accuracy of VOGN suffers on this dataset.</p>
<p>Overall, the take-away from this experiment is that, on these benchmarks, subnetwork inference is generally better calibrated, and therefore more robust to distribution shift, than the state-of-the-art baselines for uncertainty estimation in deep neural nets that we compared against, without sacrificing accuracy.</p>
<h2 id="take-home-message">Take-home message</h2>
<p>To conclude, in this blog post, we described subnetwork inference, which is a Bayesian deep learning method that does expressive inference over a carefully chosen subnetwork within a neural network. We also showed some empirical results suggesting that this works better than doing crude inference over the full network. There remain clear limitations of this work that deserve more investigation in the future. The most important one concerns subnetwork selection: we need better selection strategies that avoid the potentially restrictive approximations we use (<em>i.e.</em> the diagonal approximation to the posterior covariance matrix).</p>
<p>Thanks a lot for reading this blog post! If you want to learn more about this work, please feel free to check out our full ICML 2021 paper:</p>
<ul>
<li>Erik Daxberger, Eric Nalisnick, James Urquhart Allingham, Javier Antorán, José Miguel Hernández-Lobato. <a href="https://arxiv.org/abs/2010.14689">Bayesian Deep Learning via Subnetwork Inference</a>. In <em>ICML 2021</em>.</li>
</ul>
<p>Finally, we would like to thank Stratis Markou, Wessel Bruinsma and Richard Turner for many helpful comments on this blog post!</p>Erik DaxbergerBayesian inference has the potential to address shortcomings of deep neural networks (DNNs) such as poor calibration. However, scaling Bayesian methods to modern DNNs is challenging. This blog post describes subnetwork inference, a method that tackles this issue by doing inference over only a small, carefully selected subset of the DNN weights.Reinforcement Learning for 3D Molecular Design2021-04-30T00:00:00+00:002021-04-30T00:00:00+00:00https://mlg.eng.cam.ac.uk/blog/2021/04/30/reinforcement-learning-for-3d-molecular-design<p>Imagine we were able to design molecules with exactly the properties we care about. This would unlock huge potential for applications such as de novo drug design and materials discovery. Unfortunately, searching for particular chemical compounds is like trying to find the needle in a haystack: <a href="https://doi.org/10.1007/s10822-013-9672-4">Polishchuk <em>et al.</em> (2013)</a> estimate that there are between $10^{30}$ and $10^{60}$ feasible and potentially drug-like molecules, making exhaustive search hopeless. Worse yet, we don’t even know what the needle looks like.</p>
<p>In this blog post, we will outline how we combine ideas from reinforcement learning and quantum chemistry to catalyse the search for new molecules. We will explain how we can push the boundaries of the type of molecules we can build by representing the atoms directly in Cartesian coordinates. Finally, we will demonstrate how we can exploit symmetries of the design process to efficiently train a reinforcement learning agent for molecular design.</p>
<h2 id="re-framing-molecular-design-using-reinforcement-learning-and-quantum-chemistry">Re-framing Molecular Design using Reinforcement Learning and Quantum Chemistry</h2>
<p>To be able to design general molecular structures, it is critical to choose the right representation. Most approaches rely on graph representations of molecules, where atoms and bonds are represented by nodes and edges, respectively. However, this is a strongly simplified model designed for the description of single organic molecules. It is unsuitable for encoding metals and molecular clusters as it lacks information about the relative position of atoms in 3D space. Further, geometric constraints on the molecule cannot be easily encoded in the design process. A more general representation closer to the physical system is one in which a molecule is described by its atoms’ positions in Cartesian coordinates. We therefore directly work in this space.</p>
<p>In particular, we design molecules by sequentially drawing atoms from a given bag and placing them onto a 3D canvas. The canvas $\mathcal{C}$ contains all atoms $(e, x)$ with element $e \in \{\ce{H}, \ce{C}, \ce{N}, \ce{O}, \dots \}$ and position $x \in \mathbb{R}^3$ placed so far, whereas the bag $\mathcal{B}$ comprises the atoms still to be placed. We formulate this task as a sequential decision-making problem in a Markov decision process, where the agent is rewarded for building stable molecules. At the beginning of each episode, the agent receives an initial state $s_0 = (\mathcal{C}_{0}, \mathcal{B}_0)$, <em>e.g.</em> $\mathcal{C}_0 = \emptyset$ and $\mathcal{B}_0 = \ce{SOF_4}$ (see <a href="#figure-env">Figure 1</a>). At each timestep $t$, the agent draws an atom from the bag and places it onto the canvas through action $a_t$, yielding reward $r_t$ and transitioning the environment into state $s_{t+1}$. This process is repeated until the bag is empty.</p>
<div class="image-container">
<img src="/blog/assets/images/molgym/env.png" alt="We build a molecule by repeatedly taking atoms from bag $\mathcal{B}_0 = \ce{SOF_4}$ and placing them onto the 3D canvas. Bonds connecting atoms are only for illustration." id="figure-env" style="width: 100%; max-width: 650px" />
<p class="caption">
Figure 1: We build a molecule by repeatedly taking atoms from bag $\mathcal{B}_0 = \ce{SOF_4}$ and placing them onto the 3D canvas. Bonds connecting atoms are only for illustration.
</p>
</div>
<p>An advantage of designing molecules in Cartesian space is that we can evaluate states in terms of quantum-mechanical properties.<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup> Here, the reward function encourages the agent to design stable molecules as measured in terms of their energy; however, linear combinations of other desirable properties (like drug-likeness or toxicity) would also be possible. We define the reward as the negative difference in energy between the resulting molecule described by $\mathcal{C}_{t+1}$ and the sum of energies of the current molecule $\mathcal{C}_t$ and a new atom of element $e_t$,
\begin{equation}
r(s_t, a_t) = \left[E(\mathcal{C}_t) + E(e_t)\right] - E(\mathcal{C}_{t+1}),
\end{equation}
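where $E(e) := E(\{(e, [0,0,0]^T)\})$ is the energy of a single atom of element $e$ placed at the origin. We compute the energy using a fast semi-empirical quantum-chemical method. Importantly, the episodic return for building a molecule does not depend on the order in which atoms are placed, because the intermediate energies $E(\mathcal{C}_t)$ cancel when the rewards are summed.</p>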
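<p>The order-invariance follows from a telescoping sum: adding up $r(s_t, a_t)$ over an episode of length $T$ gives $\sum_t E(e_t) + E(\mathcal{C}_0) - E(\mathcal{C}_T)$, so only the individual atoms and the initial and final canvases matter. The Python sketch below checks this numerically. It replaces the semi-empirical quantum-chemical method with a <em>hypothetical</em> toy energy (made-up per-element constants plus a Lennard-Jones pair term), so the numbers carry no chemical meaning; only the structure of the reward is taken from the text.</p>

```python
import math
from itertools import permutations

# Hypothetical per-element constants; these are NOT real atomic energies.
ATOMIC_ENERGY = {"H": -0.5, "C": -37.8, "N": -54.6, "O": -75.0}


def toy_energy(canvas):
    """Toy stand-in for a quantum-chemical energy E(C).

    canvas: list of (element, (x, y, z)) tuples. The energy is a sum of per-element
    constants plus a Lennard-Jones pair term (sigma = epsilon = 1, arbitrary units).
    """
    energy = sum(ATOMIC_ENERGY[element] for element, _ in canvas)
    for i in range(len(canvas)):
        for j in range(i + 1, len(canvas)):
            r = math.dist(canvas[i][1], canvas[j][1])
            energy += 4.0 * (r ** -12 - r ** -6)
    return energy


def reward(canvas, atom):
    """r(s_t, a_t) = [E(C_t) + E(e_t)] - E(C_{t+1}), with E(e) for a lone atom at the origin."""
    element, _ = atom
    lone_atom = [(element, (0.0, 0.0, 0.0))]
    return toy_energy(canvas) + toy_energy(lone_atom) - toy_energy(canvas + [atom])


def episodic_return(atoms_in_placement_order):
    canvas, total = [], 0.0          # C_0 is the empty canvas
    for atom in atoms_in_placement_order:
        total += reward(canvas, atom)
        canvas = canvas + [atom]
    return total


# A water-like arrangement (coordinates made up); every placement order gives the same return.
molecule = [("O", (0.0, 0.0, 0.0)), ("H", (0.96, 0.0, 0.0)), ("H", (-0.24, 0.93, 0.0))]
for order in permutations(molecule):
    print(f"{episodic_return(list(order)):.6f}")
```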
<h2 id="exploiting-symmetry-using-internal-coordinates-simm-et-al-2020">Exploiting Symmetry using Internal Coordinates (<a href="http://proceedings.mlr.press/v119/simm20b.html">Simm et al., 2020</a>)</h2>
<p>Learning to place atoms in Cartesian coordinates requires that the agent exploit the symmetries of the molecular design process. Therefore, we need a policy $\pi(a_t \vert s_t)$ that is <em>covariant</em><sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup> to translation and rotation. In other words, if the canvas is rotated or translated, the position $x$ of the atom to be placed should be rotated and translated as well.</p>
<p>To achieve this, we first encode the current state $s_t$ into an invariant representation $s^\text{inv} = \mathsf{SchNet}(s_t)$, where $\mathsf{SchNet}$ (<a href="https://proceedings.neurips.cc/paper/2017/hash/303ed4c69846ab36c2904d3ba8573050-Abstract.html">Schütt <em>et al.</em>, 2017</a>; <a href="https://doi.org/10.1063/1.5019779">Schütt <em>et al.</em>, 2018</a>) is a neural network architecture that models interactions between atoms. Given $s^\text{inv}$, our agent selects 1) a focal atom $f$ among already placed atoms that acts as a reference point, 2) an available element $e$ from the bag for the atom to be placed, and 3) the position of the atom to be placed in internal coordinates (see <a href="#figure-internal_agent">Figure 2</a>). These coordinates consist of the distance $d$ to the focal atom as well as angles $\alpha$ and $\psi$ with respect to the focal atom and its neighbours. Finally, we obtain a position $x$ that features the required covariance by mapping the internal coordinates back to Cartesian coordinates. We call the resulting agent $\mathsf{Internal}$.</p>
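<p>The mapping from internal back to Cartesian coordinates can be illustrated with a standard NeRF-style (natural extension reference frame) construction. In the sketch below we assume that $\alpha$ is the angle between the new atom, the focal atom and a first reference neighbour, and that $\psi$ is the dihedral angle defined with a second reference neighbour; the paper’s exact choice of reference atoms and sign convention may differ, and the function name is hypothetical. Because the local frame is built only from differences and cross products of the reference positions, a proper rotation and translation of the canvas rotates and translates the output in the same way, which is the covariance property described above.</p>

```python
import numpy as np


def internal_to_cartesian(focal, ref1, ref2, d, alpha, psi):
    """Cartesian position of a new atom from internal coordinates (d, alpha, psi).

    d     : distance between the new atom and `focal`
    alpha : angle new-atom / focal / ref1 (radians)
    psi   : dihedral angle about the ref1-focal axis, measured using ref2 (radians)
    """
    e1 = (focal - ref1) / np.linalg.norm(focal - ref1)   # unit vector along ref1 -> focal
    n = np.cross(ref1 - ref2, e1)                         # normal of the (ref2, ref1, focal) plane
    n = n / np.linalg.norm(n)                             # undefined if the three are collinear
    e2 = np.cross(n, e1)                                  # completes a right-handed frame
    local = d * np.array([-np.cos(alpha),
                          np.sin(alpha) * np.cos(psi),
                          np.sin(alpha) * np.sin(psi)])
    return focal + np.column_stack([e1, e2, n]) @ local


focal = np.array([0.0, 0.0, 0.0])
ref1 = np.array([1.0, 0.0, 0.0])
ref2 = np.array([1.5, 1.0, 0.0])
x = internal_to_cartesian(focal, ref1, ref2, d=1.1, alpha=np.deg2rad(109.5), psi=np.deg2rad(60.0))
```

Note that this frame is undefined when the focal atom and both reference atoms are collinear, since the cross product then vanishes; a robust implementation would need to handle that case.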
<div class="image-container">
<img src="/blog/assets/images/molgym/internal_agent.png" alt="The $\mathsf{Internal}$ agent places an atom from the bag (highlighted in orange) relative to the focal atom (highlighted in purple), where the internal coordinates $(d, \alpha, \psi)$ uniquely determine its absolute position." id="figure-internal_agent" style="width: 100%; max-width: 500px" />
<p class="caption">
Figure 2: The $\mathsf{Internal}$ agent places an atom from the bag (highlighted in orange) relative to the focal atom (highlighted in purple), where the internal coordinates $(d, \alpha, \psi)$ uniquely determine its absolute position.
</p>
</div>
<h2 id="so-does-it-work">So… does it work?</h2>
<p>Equipped with a policy, we can finally design some molecules! To demonstrate how the $\mathsf{Internal}$ agent works, we separately train it on the bags $\ce{CH_3NO}, \ce{CH_4O}$ and $\ce{C_2H_2O_2}$ using PPO (<a href="https://arxiv.org/abs/1707.06347">Schulman <em>et al.</em>, 2017</a>). <a href="#figure-singles">Figure 3</a> shows that the agent is able to learn interatomic distances as well as the rules of chemical bonding from scratch. On average, the agent reaches $90\%$ of the optimal return<sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">3</a></sup> after only $12\,000$ steps. However, from the snapshots $\enclose{circle}{2}$ and $\enclose{circle}{3}$ in <a href="#figure-singles">Figure 3</a> (b), we can see that the generated structures are not quite optimal yet. While the policy has mostly learned the atomic distances, it still has to figure out the angles between atoms. After training the policy for a bit longer, at point $\enclose{circle}{4}$ we finally generate valid, stable molecules. It works!</p>
<div class="image-container">
<img src="/blog/assets/images/molgym/singles.png" alt="(a) The $\mathsf{Internal}$ agent is able to build stable molecules from the bags $\ce{CH3NO}, \ce{CH4O}$ and $\ce{C2H2O2}$. Each dashed line denotes the optimal return for the corresponding bag.
(b) Generated molecular structures at different terminal states over time show the agent's learning progress." id="figure-singles" style="width: 100%; max-width: 800px" />
<p class="caption">
Figure 3: (a) The $\mathsf{Internal}$ agent is able to build stable molecules from the bags $\ce{CH3NO}, \ce{CH4O}$ and $\ce{C2H2O2}$. Each dashed line denotes the optimal return for the corresponding bag.
(b) Generated molecular structures at different terminal states over time show the agent's learning progress.
</p>
</div>
<p>While these results look promising, the $\mathsf{Internal}$ agent actually struggles when faced with highly symmetric structures. As shown in <a href="#figure-fail">Figure 4</a>, that is because the choice of angles $\alpha$ and $\psi$ used in the internal coordinates can become ambiguous in such cases. A better approach would be to directly generate a rotation-covariant orientation $\tilde{x}$ of the atom to be placed without going through these internal coordinates.</p>
<div class="image-container">
<img src="/blog/assets/images/molgym/fail.png" alt="Example of two configurations (a) and (b) that the $\mathsf{Internal}$ agent cannot distinguish. While the values for distance $d$, and angles $\alpha$ and $\psi$ are the same, choosing different reference points (in red) leads to a different action. This is particularly problematic in symmetric states, where one cannot uniquely determine these reference points." id="figure-fail" style="width: 100%; max-width: 450px" />
<p class="caption">
Figure 4: Example of two configurations (a) and (b) that the $\mathsf{Internal}$ agent cannot distinguish. While the values for distance $d$, and angles $\alpha$ and $\psi$ are the same, choosing different reference points (in red) leads to a different action. This is particularly problematic in symmetric states, where one cannot uniquely determine these reference points.
</p>
</div>
<h2 id="exploiting-symmetry-using-spherical-harmonics-simm-et-al-2021">Exploiting Symmetry using Spherical Harmonics (<a href="https://openreview.net/forum?id=jEYKjPE1xYN">Simm et al., 2021</a>)</h2>
<p>Therefore, we replace the angles $\alpha$ and $\psi$ by directly sampling the orientation from a distribution on a sphere with radius $d$ centred at the focal atom $f$ (see <a href="#figure-covariant_agent">Figure 5</a>). We can define such a distribution using <em>spherical harmonics</em>, which are basis functions defined on the sphere. In particular, by generating the right coefficients $\hat{r}$ for these basis functions, we can in principle approximate any distribution on the sphere, including multi-modal ones. To produce the coefficients, we modify $\mathsf{Cormorant}$ (<a href="https://papers.nips.cc/paper/2019/hash/03573b32b2746e6e8ca98b9123f2249b-Abstract.html">Anderson <em>et al.</em>, 2019</a>), a neural network architecture for predicting properties of chemical systems that works entirely in Fourier space. Finally, we sample a rotation-covariant orientation $\tilde{x}$ from the spherical distribution defined by $\hat{r}$ using rejection sampling. We call the resulting agent $\mathsf{Covariant}$.</p>
<div class="image-container">
<img src="/blog/assets/images/molgym/covariant_agent.png" alt="The $\mathsf{Covariant}$ agent chooses focal atom $f$, element $e$, distance $d$, and orientation $\tilde{x}$. We then map back to global coordinates $x$ to obtain action $a_t = (e, x)$." id="figure-covariant_agent" style="width: 100%; max-width: 650px" />
<p class="caption">
Figure 5: The $\mathsf{Covariant}$ agent chooses focal atom $f$, element $e$, distance $d$, and orientation $\tilde{x}$. We then map back to global coordinates $x$ to obtain action $a_t = (e, x)$.
</p>
</div>
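<p>The sampling step can be illustrated with a small sketch. As an assumption for illustration only (not necessarily the parametrisation used by $\mathsf{Covariant}$), take an unnormalised density on the unit sphere $p(u) \propto \exp\big(\sum_{l \le 2}\sum_m \hat{r}_{lm} Y_{lm}(u)\big)$ built from orthonormal real spherical harmonics up to degree 2, and sample it by rejection from a uniform proposal. The addition theorem gives $|Y_{lm}(u)| \le \sqrt{(2l+1)/(4\pi)}$, which yields a valid upper bound on the log-density. A candidate position is then, for instance, the focal atom's position plus $d\,\tilde{x}$.</p>

```python
import numpy as np

# Degree l of each real spherical harmonic, in the order returned by real_sph_harm.
DEGREES = np.array([0, 1, 1, 1, 2, 2, 2, 2, 2])

def real_sph_harm(u):
    """Orthonormal real spherical harmonics Y_lm(u) for l <= 2 at a unit vector u.

    Order: (0,0), (1,-1), (1,0), (1,1), (2,-2), (2,-1), (2,0), (2,1), (2,2).
    """
    x, y, z = u
    c0 = 0.5 * np.sqrt(1 / np.pi)
    c1 = np.sqrt(3 / (4 * np.pi))
    c2 = 0.5 * np.sqrt(15 / np.pi)
    c20 = 0.25 * np.sqrt(5 / np.pi)
    return np.array([c0,
                     c1 * y, c1 * z, c1 * x,
                     c2 * x * y, c2 * y * z, c20 * (3 * z**2 - 1), c2 * x * z,
                     0.5 * c2 * (x**2 - y**2)])

def sample_orientation(coeffs, rng, max_tries=100_000):
    """Rejection-sample a unit vector from p(u) proportional to exp(coeffs . Y(u))."""
    # |Y_lm(u)| <= sqrt((2l + 1) / (4 pi)) bounds the log-density from above.
    log_bound = np.sum(np.abs(coeffs) * np.sqrt((2 * DEGREES + 1) / (4 * np.pi)))
    for _ in range(max_tries):
        u = rng.normal(size=3)
        u /= np.linalg.norm(u)                      # uniform proposal on the sphere
        log_p = coeffs @ real_sph_harm(u)
        if np.log(rng.uniform()) < log_p - log_bound:
            return u
    raise RuntimeError("no proposal accepted; the bound may be too loose")

# Usage: a large Y_20 coefficient gives two antipodal modes around +z and -z.
rng = np.random.default_rng(0)
coeffs = np.zeros(9)
coeffs[6] = 6.0
u = sample_orientation(coeffs, rng)
focal, d = np.zeros(3), 1.5                         # made-up focal position and distance
x = focal + d * u                                   # candidate position of the new atom
```

Even this low-degree expansion is multi-modal; higher degrees allow sharper and more numerous modes at the cost of a looser rejection bound.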
<p>To verify that $\mathsf{Covariant}$ works as expected, we compare it to the previous $\mathsf{Internal}$ agent on structures with high symmetry and high coordination numbers. As shown in <a href="#figure-complexes">Figure 6</a> (a), $\mathsf{Covariant}$ is able to build valid molecules from the bags $\ce{SOF4}$ and $\ce{IF5}$ within $30\,000$ to $40\,000$ steps, whereas $\mathsf{Internal}$ fails to build low-energy configurations as it cannot distinguish highly symmetric intermediates. Further results in <a href="#figure-complexes">Figure 6</a> (b) for $\ce{SOF6}$ and $\ce{SF6}$ show that $\mathsf{Covariant}$ is capable of building such structures. While the constructed molecules are small, recall that they would be unattainable with graph-based methods, as those methods lack important 3D information.</p>
<div class="image-container">
<img src="/blog/assets/images/molgym/complexes.png" alt="(a) The $\mathsf{Covariant}$ agent succeeds in building stable molecules from the bags $\ce{SOF4}$ (left) and $\ce{IF5}$ (right). In contrast, $\mathsf{Internal}$ fails as it cannot distinguish highly symmetric structures. In the lower right, you can see molecular structures generated by the agents. Dashed lines denote the optimal return for each experiment.
(b) Further molecular structures generated by $\mathsf{Covariant}$, namely $\ce{SOF6}$ and $\ce{SF6}$." id="figure-complexes" style="width: 100%; max-width: 800px" />
<p class="caption">
Figure 6: (a) The $\mathsf{Covariant}$ agent succeeds in building stable molecules from the bags $\ce{SOF4}$ (left) and $\ce{IF5}$ (right). In contrast, $\mathsf{Internal}$ fails as it cannot distinguish highly symmetric structures. In the lower right, you can see molecular structures generated by the agents. Dashed lines denote the optimal return for each experiment.
(b) Further molecular structures generated by $\mathsf{Covariant}$, namely $\ce{SOF6}$ and $\ce{SF6}$.
</p>
</div>
<h2 id="thats-it--a-first-step-towards-general-molecular-design-in-cartesian-coordinates">That’s it — a first step towards general molecular design in Cartesian coordinates</h2>
<p>Let’s end here for now, even though we were really just getting started. To summarise, we have presented a novel reinforcement learning formulation for 3D molecular design guided by quantum chemistry. The key insight to get it to work was to exploit the symmetries of the design process, particularly using spherical harmonics.</p>
<p>One aspect we didn’t show today is how flexible this framework actually is. For example, we can use it to learn across multiple bags at the same time and generalise (to some extent) to unseen bags. Of course we can also scale up to larger molecules, though not yet to the sizes that graph-based methods can handle. Finally, we can even build molecular clusters, <em>e.g.</em> to model solvation processes. If that sounds interesting to you, make sure to check out the full papers:</p>
<ol>
<li><a href="http://proceedings.mlr.press/v119/simm20b.html">Simm <em>et al.</em> (2020)</a>. Reinforcement Learning for Molecular Design Guided by Quantum Mechanics. ICML 2020.</li>
<li><a href="https://openreview.net/forum?id=jEYKjPE1xYN">Simm <em>et al.</em> (2021)</a>. Symmetry-Aware Actor-Critic for 3D Molecular Design. ICLR 2021.</li>
</ol>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p>In contrast, graph-based approaches have to resort to heuristic reward functions. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:2" role="doc-endnote">
<p>More precisely, only the position $x$ needs to be covariant, whereas the element $e$ has to be invariant. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:3" role="doc-endnote">
<p>We estimate the optimal return by using structural optimisation techniques to obtain the optimal structure and its energy. <a href="#fnref:3" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Robert PinslerAutomating the design of molecules with desirable properties can greatly accelerate the search for novel drugs and materials. However, to make further progress we need to go beyond graph-based approaches. In this blog post, we use ideas from reinforcement learning and quantum chemistry to make a first step towards 3D molecular design.Natural-Gradient Variational Inference 1: The Maths2021-04-13T00:00:00+00:002021-04-13T00:00:00+00:00https://mlg.eng.cam.ac.uk/blog/2021/04/13/ngvi-bnns-part-1<p>Bayesian Deep Learning hopes to tackle neural networks’ poorly-calibrated uncertainties by injecting some level of Bayesian thinking.
There has been mixed success, as scaling Bayesian methods to such huge models is difficult!
One promising direction of research is based on natural-gradient variational inference.
We shall motivate and derive such algorithms, and then analyse their performance at a large scale, such as on ImageNet.</p>
<p>This is the first part of a two-part blog.
This first part will involve quite a lot of detailed maths: we will derive a natural-gradient variational inference (NGVI) algorithm that can run on neural networks (NNs).
We will follow the appendices in <a href="https://arxiv.org/pdf/1806.04854.pdf">Khan et al. (2018)</a>.
NGVI algorithms differ both from stochastic-gradient VI algorithms such as Bayes-By-Backprop (<a href="https://arxiv.org/pdf/1505.05424.pdf">Blundell et al., 2015</a>), which optimise the same Bayesian VI objective function, and from Adam and SGD, which optimise a non-Bayesian point estimate of the neural network weights.</p>
<p>In the <a href="https://mlg-blog.com/2021/11/24/ngvi-bnns-part-2.html">second part of the blog</a>, we will work our way to large datasets/architectures such as ImageNet/ResNets, discussing additional approximations required, as well as analysing their promising results. The second part will closely follow a paper I was involved in, <a href="https://arxiv.org/pdf/1906.02506.pdf">Osawa et al. (2019)</a>.</p>
<p>I hope to leave the reader with an understanding of how NGVI algorithms for NNs are derived, and some intuition for their strengths and weaknesses.
I will <strong>not</strong> discuss other Bayesian neural network algorithms, nor get involved in the recent debates over what it means to be ‘Bayesian’ in deep learning!
$\newcommand{\vparam}{\boldsymbol{\theta}}$
$\newcommand{\veta}{\boldsymbol{\eta}}$
$\newcommand{\vphi}{\boldsymbol{\phi}}$
$\newcommand{\vmu}{\boldsymbol{\mu}}$
$\newcommand{\vSigma}{\boldsymbol{\Sigma}}$
$\newcommand{\vm}{\mathbf{m}}$
$\newcommand{\vF}{\mathbf{F}}$
$\newcommand{\vI}{\mathbf{I}}$
$\newcommand{\vg}{\mathbf{g}}$
$\newcommand{\vH}{\mathbf{H}}$
$\newcommand{\vs}{\mathbf{s}}$
$\newcommand{\myexpect}{\mathbb{E}}$
$\newcommand{\pipe}{\,|\,}$
$\newcommand{\data}{\mathcal{D}}$
$\newcommand{\loss}{\mathcal{L}}$
$\newcommand{\gauss}{\mathcal{N}}$</p>
<h2 id="why-natural-gradient-variational-inference">Why natural-gradient variational inference?</h2>
<p>If you are reading this blog, hopefully you already know about Bayesian inference and its many promises when combined with deep learning: in short, we hope to obtain reliable confidence estimates, avoid overfitting on small datasets, and deal naturally with sequential learning.
But exact Bayesian inference on large models such as neural networks is intractable.</p>
<p>Although there are many approximate Bayesian inference algorithms, we will only focus on variational inference. <a href="https://arxiv.org/pdf/1505.05424.pdf">Blundell et al. (2015)</a> introduced Bayes-By-Backprop for training NNs with VI. But this has been very difficult to scale to large NNs such as ResNets: the main problem is that optimisation is prohibitively slow, as it requires many passes through the data.</p>
<p>Separately, natural-gradient update steps were introduced as a principled way of incorporating the information geometry of the distribution being optimised (<a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.452.7280&rep=rep1&type=pdf">Amari, 1998</a>).
By incorporating the geometry of the distribution, we expect to take gradient steps in much better directions.
This should speed up gradient optimisation significantly.
For a more detailed explanation, please look at the motivation in papers such as <a href="https://arxiv.org/pdf/1807.04489.pdf">Khan & Nielsen (2018)</a> or <a href="https://jmlr.org/papers/volume21/17-678/17-678.pdf">Martens (2020)</a>; I found figures such as Figure 1(a) from <a href="https://arxiv.org/pdf/1807.04489.pdf">Khan & Nielsen (2018)</a> particularly useful.</p>
<p>It therefore makes sense to try and apply natural-gradient updates to VI for NNs, where speed of convergence has been an issue.
In this blog post, we do this while looking closely at the mathematical details.
We will follow the appendices in <a href="https://arxiv.org/pdf/1806.04854.pdf">Khan et al. (2018)</a> (there is a slightly different derivation in <a href="https://arxiv.org/pdf/1712.02390.pdf">Zhang et al. (2018)</a>).
I also hope that, after reading this blog, you will be able to confidently approach recent papers that use NGVI, papers which often assume some knowledge of how NGVI algorithms are derived.</p>
<h2 id="starting-with-the-basics">Starting with the basics</h2>
<p>This section is a very brief overview of some fundamental concepts we will need.
If you understand Equations \eqref{eq:exp_fam}, \eqref{eq:ELBO}, \eqref{eq:simple_NGD} & \eqref{eq:NGD}, then feel free to skip the text. If anything is unfamiliar, there will be links to some good references.</p>
<h3 id="exponential-families">Exponential Families</h3>
<p>Exponential family distributions are commonly used in machine learning, with some key properties we can use. They include Gaussian distributions, which are the specific case we will consider later. Exponential family distributions are covered in most machine learning courses, and there are many good references, such as <a href="https://probml.github.io/pml-book/book1.html">Murphy (2021)</a> and <a href="https://www.microsoft.com/en-us/research/uploads/prod/2006/01/Bishop-Pattern-Recognition-and-Machine-Learning-2006.pdf">Bishop (2006)</a>.</p>
<p>An exponential family distribution over parameters $\vparam$ with natural parameters $\veta$ has the form,</p>
<p>\begin{align} \label{eq:exp_fam}
q(\vparam|\veta) = q_{\veta}(\vparam) = h(\vparam)\exp [ \langle\veta,\vphi(\vparam)\rangle - A(\veta) ],
\end{align}</p>
<p>where $\vphi(\vparam)$ is the vector of sufficient statistics, $\langle \cdot,\cdot \rangle$ is an inner product, $A(\veta)$ is the log-partition function and $h(\vparam)$ is the base measure, which does not depend on $\veta$. We also assume a <em>minimal</em> exponential family, that is, one whose sufficient statistics are linearly independent. Minimality guarantees a one-to-one mapping between $\veta$ and the mean parameters $\vm = \myexpect_{q_\veta(\vparam)} [\vphi(\vparam)]$, which are related by $\vm = \nabla_\veta A(\veta)$.</p>
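<p>As a quick numerical check of $\vm = \nabla_\veta A(\veta)$, the sketch below uses a one-dimensional Gaussian with sufficient statistics $\vphi(\theta) = (\theta, \theta^2)$, natural parameters $\veta = (\mu/\sigma^2, -1/(2\sigma^2))$ and log-partition function $A(\veta) = -\eta_1^2/(4\eta_2) - \frac{1}{2}\log(-2\eta_2)$, dropping the constant $\frac{1}{2}\log 2\pi$, which can be absorbed into $h$. A central finite-difference gradient of $A$ should match the mean parameters $(\mu, \mu^2 + \sigma^2)$.</p>

```python
import numpy as np

def natural_params(mu, s2):
    """eta = (mu / s2, -1 / (2 s2)) for a 1D Gaussian N(mu, s2)."""
    return np.array([mu / s2, -0.5 / s2])

def log_partition(eta):
    """A(eta), dropping the constant 0.5 * log(2 pi) (absorbed into h)."""
    eta1, eta2 = eta
    return -eta1**2 / (4 * eta2) - 0.5 * np.log(-2 * eta2)

def mean_params(mu, s2):
    """m = E_q[phi(theta)] for phi(theta) = (theta, theta^2)."""
    return np.array([mu, mu**2 + s2])

def grad_fd(fun, x, h=1e-6):
    """Central finite-difference gradient of a scalar function."""
    g = np.zeros_like(x)
    for i in range(len(x)):
        e = np.zeros_like(x)
        e[i] = h
        g[i] = (fun(x + e) - fun(x - e)) / (2 * h)
    return g

eta = natural_params(mu=1.3, s2=0.4)
print(grad_fd(log_partition, eta))   # should match the next line
print(mean_params(1.3, 0.4))         # (mu, mu^2 + s2)
```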
<h3 id="variational-inference-vi">Variational inference (VI)</h3>
<p>In Bayesian inference, we want to learn the posterior distribution over parameters after observing some data $\data$. The posterior is given as,</p>
<p>\begin{equation}
p(\vparam \pipe \data) = \frac{ {\color{purple}p(\data\pipe\vparam)} {\color{blue}p_0(\vparam)}}{p(\data)}, \nonumber
\end{equation}</p>
<p>where ${\color{purple}p(\data\pipe \vparam)}$ is the data likelihood and ${\color{blue}p_0(\vparam)}$ is the prior over parameters. We will use colours to keep track of terms coming from the likelihood and prior.
Note that in supervised learning, where the dataset $\data$ consists of inputs $\mathbf{X}$ and labels $\mathbf{y}$, we should write the likelihood as ${\color{purple}p(\mathbf{y}\pipe \mathbf{X}, \vparam)}$, but we slightly abuse notation by writing ${\color{purple}p(\data\pipe \vparam)}$.</p>
<p>If our likelihood and prior are set correctly, then exact Bayesian inference is optimal, but unfortunately there are problems in reality (this statement comes with many caveats! See e.g. <a href="https://mlg-blog.com/2021/03/31/what-keeps-a-bayesian-awake-at-night-part-1.html">this blog post</a> for a more detailed discussion).
To name two problems, (i) we are usually unsure if the likelihood or prior is correct, and (ii) exact Bayesian inference is often not possible, especially in NNs. In this blog post, we only focus on approaches to problem (ii): algorithms for approximate Bayesian inference. We do not consider problem (i).</p>
<p>Variational Bayesian inference approximates exact Bayesian inference by learning the parameters of a distribution $q(\vparam)$ that best approximates the true posterior distribution $p(\vparam \pipe \data)$. We do this by maximising the Evidence Lower Bound (ELBO), which is equivalent to minimising the KL divergence between the approximate distribution and the true posterior.
By assuming that $q(\vparam)$ is an exponential family distribution $q_\veta(\vparam)$, we can write the ELBO as follows,</p>
<p>\begin{equation} \label{eq:ELBO}
\loss_\mathrm{ELBO}(\veta) = \underbrace{\myexpect_{q_\veta(\vparam)} \left[\log {\color{purple}p(\data\pipe\vparam)}\right]}_\text{Likelihood term} + \underbrace{\myexpect_{q_\veta(\vparam)} \left[\log \frac{ {\color{blue}p_0(\vparam)}}{q_\veta(\vparam)} \right]}_\text{KL (to prior) term},
\end{equation}</p>
<p>which we optimise with respect to $\veta$. There are many good references on variational inference, such as <a href="https://arxiv.org/pdf/1601.00670.pdf">Blei et al. (2018)</a> or <a href="https://arxiv.org/pdf/1711.05597.pdf">Zhang et al. (2018)</a>.</p>
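<p>To see the two ELBO terms in a case where everything is available in closed form, consider a hypothetical toy model with one observation $y \sim \gauss(\theta, \sigma^2)$, prior $\theta \sim \gauss(0, \sigma_0^2)$ and $q(\theta) = \gauss(\mu, s^2)$. The sketch computes the likelihood term and the (negative) KL term analytically. Because this model is conjugate, the ELBO equals the log evidence $\log p(y) = \log\gauss(y; 0, \sigma_0^2 + \sigma^2)$ exactly when $q$ is the true posterior, and is smaller for any other $q$.</p>

```python
import math

def elbo(mu, s2, y, prior_var=1.0, noise_var=1.0):
    """ELBO for y ~ N(theta, noise_var), theta ~ N(0, prior_var), q = N(mu, s2)."""
    # Likelihood term: E_q[log N(y; theta, noise_var)]
    lik = -0.5 * math.log(2 * math.pi * noise_var) - 0.5 * ((y - mu) ** 2 + s2) / noise_var
    # KL-to-prior term: E_q[log p0 - log q] = -KL(q || p0)
    kl = 0.5 * (s2 / prior_var + mu ** 2 / prior_var - 1 - math.log(s2 / prior_var))
    return lik - kl

def log_evidence(y, prior_var=1.0, noise_var=1.0):
    """log p(y) = log N(y; 0, prior_var + noise_var) for this conjugate model."""
    v = prior_var + noise_var
    return -0.5 * math.log(2 * math.pi * v) - y ** 2 / (2 * v)

y = 1.0
post_var = 1 / (1 / 1.0 + 1 / 1.0)        # exact posterior variance
post_mean = post_var * y / 1.0            # exact posterior mean
print(elbo(post_mean, post_var, y), log_evidence(y))   # equal, both about -1.5155
print(elbo(0.0, 1.0, y))                               # smaller: here q is the prior
```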
<h3 id="natural-gradient-ng-updates">Natural-gradient (NG) updates</h3>
<p>Let’s say that we have some function $\loss(\veta)$ that we want to optimise with respect to the natural parameters $\veta$ of an exponential family distribution. Later in this blog post, this function will be the ELBO. Natural-gradient updates iterate the following step until convergence,</p>
<p>\begin{equation} \label{eq:simple_NGD}
\veta_{t+1} = \veta_t + \beta_t \vF(\veta_t)^{-1}\nabla_\veta \loss(\veta_t),
\end{equation}</p>
<p>where $\nabla_\veta \loss(\veta_t) = \nabla_\veta \loss(\veta) \pipe_{\veta=\veta_t}$,</p>
<p>\begin{equation*}
\vF(\veta) = \myexpect_{q_\veta(\vparam)} \left[ \nabla_\veta \log q_\veta(\vparam) \nabla_\veta \log q_\veta(\vparam)^\top \right],
\end{equation*}</p>
<p>is the Fisher information matrix (evaluated at $\veta = \veta_t$ in the update above), and $\beta_t$ is a learning rate.
As previously discussed, natural-gradient methods incorporate the information geometry of the distribution being optimised (through the Fisher information matrix), and can therefore reduce the number of gradient steps required.
Some good references include <a href="http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.452.7280&rep=rep1&type=pdf">Amari (1998)</a> and <a href="https://arxiv.org/pdf/1412.1193.pdf">Martens (2014)</a>.</p>
<p>We can use a neat trick of exponential families to simplify the update step and side-step having to compute and invert the Fisher matrix directly (see e.g. <a href="http://www.columbia.edu/~jwp2128/Papers/HoffmanBleiWangPaisley2013.pdf">Hoffman et al. (2013)</a> or <a href="https://arxiv.org/pdf/1703.04265.pdf">Khan & Lin (2017)</a>).
One way to show this is to note that</p>
<p>\begin{equation} \label{eq:mean-natural gradient}
\nabla_\veta \loss(\veta_t) = [\nabla_\veta \vm_t] \nabla_\vm \loss_*(\vm_t) = [\nabla^2_{\veta\veta}A(\veta_t)] \nabla_\vm \loss_*(\vm_t) = \vF(\veta_t) \nabla_\vm \loss_*(\vm_t),
\end{equation}</p>
<p>where $\loss_*(\vm)$ is the same function as $\loss(\veta)$ except written in terms of the mean parameters $\vm$.
We have used the fact that $\vF(\veta) = \nabla^2_{\veta\veta}A(\veta)$; see the earlier references for a proof.</p>
<p>Plugging this in, we get our simplified natural-gradient update step,</p>
<p>\begin{equation} \label{eq:NGD}
\veta_{t+1} = \veta_t + \beta_t \nabla_\vm \loss_*(\vm_t).
\end{equation}</p>
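<p>The identity $\vF(\veta)^{-1}\nabla_\veta \loss(\veta) = \nabla_\vm \loss_*(\vm)$ behind this simplification can be checked numerically. The sketch below again uses a one-dimensional Gaussian, whose Fisher matrix in natural parameters is the covariance of $(\theta, \theta^2)$, together with an arbitrary example objective $\loss_*(\vm) = \myexpect_q[-(\theta - 2)^2] = -(m_2 - 4m_1 + 4)$, chosen only because its mean-parameter gradient $(4, -1)$ is obvious. Solving against the Fisher matrix recovers that gradient without ever differentiating with respect to $\vm$.</p>

```python
import numpy as np

def eta_to_m(eta):
    """Mean parameters (mu, mu^2 + s2) of a 1D Gaussian from its natural parameters."""
    eta1, eta2 = eta
    s2 = -0.5 / eta2
    mu = eta1 * s2
    return np.array([mu, mu**2 + s2])

def fisher(eta):
    """F(eta) = Cov_q[(theta, theta^2)], which equals the Hessian of A(eta)."""
    mu, second = eta_to_m(eta)
    s2 = second - mu**2
    return np.array([[s2, 2 * mu * s2],
                     [2 * mu * s2, 2 * s2**2 + 4 * mu**2 * s2]])

def loss_m(m):
    """Example objective L*(m) = E_q[-(theta - 2)^2] = -(m2 - 4 m1 + 4)."""
    return -(m[1] - 4 * m[0] + 4)

def loss_eta(eta):
    """The same objective written as a function of eta."""
    return loss_m(eta_to_m(eta))

def grad_fd(fun, x, h=1e-6):
    """Central finite-difference gradient of a scalar function."""
    g = np.zeros_like(x)
    for i in range(len(x)):
        e = np.zeros_like(x)
        e[i] = h
        g[i] = (fun(x + e) - fun(x - e)) / (2 * h)
    return g

eta = np.array([1.3 / 0.4, -0.5 / 0.4])                      # mu = 1.3, s2 = 0.4
nat_grad = np.linalg.solve(fisher(eta), grad_fd(loss_eta, eta))
print(nat_grad)                  # close to grad_m L*(m) = [4, -1]
eta_next = eta + 0.1 * nat_grad  # one simplified natural-gradient step
```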
<h2 id="the-details-natural-gradient-vi">The details: Natural-gradient VI</h2>
<p>We wish to combine variational inference with natural-gradient updates, so let’s get straight into it: we plug the ELBO (Equation \eqref{eq:ELBO}) into the natural-gradient update (Equation \eqref{eq:NGD}). Let the prior ${\color{blue}p_0(\vparam)}$ be an exponential family with natural parameters ${\color{blue}\veta_0}$. We first note that the KL term in the ELBO can be simplified as we are dealing with exponential families,</p>
<p>\begin{align}
\nabla_\vm \,\text{KL term}
&= \nabla_\mathbf{m} \mathbb{E}_{q_\eta(\boldsymbol{\theta})} \left[ \boldsymbol{\phi}(\boldsymbol{\theta})^\top ({\color{blue}\veta_0} - \veta) + A(\veta) + \text{const} \right] \nonumber\newline
&= \nabla_\mathbf{m} \left[ \mathbf{m}^\top ({\color{blue}\veta_0} - \veta) \right] + \nabla_\mathbf{m} A(\veta) \nonumber\newline
&= {\color{blue}\veta_0} - \veta - \left[ \nabla_\mathbf{m}\veta \right]^\top \mathbf{m} + \nabla_\mathbf{m} A(\veta) \nonumber\newline
&= {\color{blue}\veta_0} - \veta - \mathbf{F}(\veta)^{-1}\mathbf{m} + \mathbf{F}(\veta)^{-1}\mathbf{m} \nonumber\newline
&= {\color{blue}\veta_0} - \veta. \nonumber
\end{align}</p>
<p>The third line follows using the product rule, and the fourth line uses $\nabla_\mathbf{m}(\cdot) = \mathbf{F}(\veta)^{-1} \nabla_\veta(\cdot)$ from Equation \eqref{eq:mean-natural gradient} and the symmetry of the Fisher information matrix.
Plugging the ELBO (with this simplification) into Equation \eqref{eq:NGD},</p>
<p>\begin{align}
\veta_{t+1} &= \veta_t + \beta_t \left( \nabla_\vm \myexpect_{q_{\veta_t}(\vparam)} \left[\log {\color{purple}p(\data\pipe\vparam)}\right] + ({\color{blue}\veta_0} - \veta_t) \right) \nonumber\newline
\label{eq:BLR}
\therefore \veta_{t+1} &= (1-\beta_t) \veta_t + \beta_t \Big({\color{blue}\veta_0} + \nabla_\vm \underbrace{\myexpect_{q_{\veta_t}(\vparam)} \left[\log {\color{purple}p(\data\pipe\vparam)}\right]}_{ {\color{purple}\Large\mathcal{F}_t}} \Big).
\end{align}</p>
<p>This equation is presented and analysed in detail in <a href="https://arxiv.org/pdf/2107.04562.pdf">Khan & Rue (2021)</a>, where it is called the ‘Bayesian learning rule’.
I recommend reading the paper if you are interested in this and related topics: they show how this equation appears in many different scenarios (beyond just the Bayesian derivation presented above), and also consider extensions beyond what we consider in this blog post. This allows them to connect to a plethora of different learning algorithms ranging from Newton’s method to Kalman filters to Adam.</p>
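<p>A simple way to build intuition for Equation \eqref{eq:BLR} is a conjugate model, where the update can be run exactly. In the sketch below (a hypothetical example with synthetic data), the likelihood is Gaussian with known noise variance, so ${\color{purple}\mathcal{F}_t}$ is linear in $\vm$ and $\nabla_\vm {\color{purple}\mathcal{F}_t}$ does not depend on $t$. The update is then a moving average towards ${\color{blue}\veta_0} + \nabla_\vm {\color{purple}\mathcal{F}_t}$, which is exactly the natural parameter of the true posterior; with $\beta_t = 1$ it would get there in a single step.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
sigma2, delta = 0.5, 1.0                                  # known noise variance, prior precision
y = rng.normal(loc=2.0, scale=np.sqrt(sigma2), size=20)   # synthetic observations

# Natural parameters (eta1, eta2) = (mu / s2, -1 / (2 s2)).
eta0 = np.array([0.0, -0.5 * delta])                      # prior N(0, 1 / delta)
# E_q[log p(D | theta)] = sum_i (y_i m1 - m2 / 2) / sigma2 + const, so its m-gradient is constant.
grad_m_F = np.array([y.sum() / sigma2, -0.5 * len(y) / sigma2])

eta, beta = np.array([0.0, -0.5]), 0.3                    # arbitrary start, learning rate
for t in range(50):
    eta = (1 - beta) * eta + beta * (eta0 + grad_m_F)     # Bayesian learning rule

post_prec = delta + len(y) / sigma2                       # exact conjugate posterior
post_mean = (y.sum() / sigma2) / post_prec
print(-2 * eta[1], post_prec)                             # precisions agree
print(eta[0] / (-2 * eta[1]), post_mean)                  # means agree
```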
<h3 id="gaussian-approximating-family">Gaussian approximating family</h3>
<p>We now consider a Gaussian approximating family, $q_\veta(\vparam) = \gauss(\vparam; \vmu, \vSigma)$.
We will substitute the Gaussian’s natural parameters into Equation \eqref{eq:BLR} to obtain updates for $\vmu$ and $\vSigma$.
The minimal representation for a Gaussian family has two components to its natural parameters and mean parameters,</p>
<p>\begin{align}
\veta^{(1)} &= \vSigma^{-1}\vmu,
& \veta^{(2)} &= -\frac{1}{2}\vSigma^{-1}, \nonumber\newline
\vm^{(1)} &= \vmu,
& \vm^{(2)} &= \vmu\vmu^\top + \vSigma. \nonumber
\end{align}</p>
<p>Let the prior be a zero-mean Gaussian, ${\color{blue}p_0(\vparam) = \gauss(\vparam; \boldsymbol{0}, \delta^{-1}\vI)}$. We can therefore write the prior natural parameters as ${\color{blue}\veta_0^{(1)} = \boldsymbol{0}, \veta_0^{(2)} = -\frac{1}{2}\delta\vI}.$</p>
<p>We now simplify $\nabla_\vm {\color{purple}\mathcal{F}_t}$ to be in terms of $\vmu$ and $\vSigma$ instead of $\vm$. We can use the chain rule to do this (see e.g. <a href="http://www0.cs.ucl.ac.uk/staff/c.archambeau/publ/neco_mo09_web.pdf">Opper & Archambeau (2009)</a> or Appendix B.1 in <a href="https://arxiv.org/pdf/1703.04265.pdf">Khan & Lin, 2017</a>),</p>
<p>\begin{align}
\nabla_{\vm^{(1)}}{\color{purple}\mathcal{F}_t} &= \nabla_\vmu {\color{purple}\mathcal{F}_t} - 2[\nabla_\vSigma {\color{purple}\mathcal{F}_t}] \vmu, \nonumber\newline
\nabla_{\vm^{(2)}}{\color{purple}\mathcal{F}_t} &= \nabla_\vSigma {\color{purple}\mathcal{F}_t}. \nonumber
\end{align}</p>
<p>We would like to write out the natural-gradient updates (Equation \eqref{eq:BLR}) for the parameters of a Gaussian, with the resulting equations in terms of the prior natural parameters ${\color{blue}\veta_0}$ and the data ${\color{purple}\mathcal{F}_t}$.
So let’s substitute the above derivations into Equation \eqref{eq:BLR}. We start with the second element, $\veta^{(2)}$, giving us an update for $\vSigma^{-1}$,</p>
<p>\begin{equation} \label{eq:Gaussian_Sigma}
\vSigma_{t+1}^{-1} = (1-\beta_t)\vSigma_t^{-1} + \beta_t ({\color{blue}\delta\vI} - 2\nabla_\vSigma {\color{purple}\mathcal{F}_t}).
\end{equation}</p>
<p>We also obtain an update for the mean $\vmu$ by looking at the first element $\veta^{(1)}$,</p>
<p>\begin{align}
\vSigma_{t+1}^{-1} \vmu_{t+1} &= (1-\beta_t)\vSigma_{t}^{-1} \vmu_{t} + \beta_t ({\color{blue}\boldsymbol{0}} + \nabla_\vmu {\color{purple}\mathcal{F}_t} - 2 [\nabla_\vSigma {\color{purple}\mathcal{F}_t}] \vmu_t) \nonumber\newline
&= \underbrace{\left[ (1-\beta_t)\vSigma_t^{-1} + \beta_t ({\color{blue}\delta\vI} - 2\nabla_\vSigma {\color{purple}\mathcal{F}_t}) \right]}_{=\vSigma_{t+1}^{-1}\text{, by Equation \eqref{eq:Gaussian_Sigma}}} \vmu_t + \beta_t (\nabla_\vmu {\color{purple}\mathcal{F}_t} - {\color{blue}\delta} \vmu_t) \nonumber\newline
\label{eq:Gaussian_mu}
\therefore \vmu_{t+1} &= \vmu_t + \beta_t \vSigma_{t+1} (\nabla_\vmu {\color{purple}\mathcal{F}_t} - {\color{blue}\delta} \vmu_t).
\end{align}</p>
<p>The update for the precision $\vSigma^{-1}$ (Equation \eqref{eq:Gaussian_Sigma}) is an exponential moving average: the precision gradually moves towards, and then tracks, $({\color{blue}\delta\vI} - 2\nabla_\vSigma {\color{purple}\mathcal{F}_t})$.
The update for the mean $\vmu$ (Equation \eqref{eq:Gaussian_mu}) is very similar to a (stochastic) gradient-ascent step on the mean. A key difference is the additional $\vSigma_{t+1}$ term, which (loosely speaking!) is the ‘natural-gradient’ part of the update: it automatically scales the learning rate for different elements of the mean $\vmu$.
In the second part of the blog post, we will see how this relates to other algorithms such as Adam, which also try to set per-parameter learning rates automatically from data.</p>
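<p>The two Gaussian updates can be tried on Bayesian linear regression, a hypothetical example where ${\color{purple}\mathcal{F}_t}$ has closed-form gradients: with likelihood $\gauss(\mathbf{y}; \mathbf{X}\vparam, \sigma^2\vI)$, we have $\nabla_\vmu {\color{purple}\mathcal{F}_t} = \mathbf{X}^\top(\mathbf{y} - \mathbf{X}\vmu_t)/\sigma^2$ and $\nabla_\vSigma {\color{purple}\mathcal{F}_t} = -\mathbf{X}^\top\mathbf{X}/(2\sigma^2)$. The sketch iterates Equations \eqref{eq:Gaussian_Sigma} and \eqref{eq:Gaussian_mu} on synthetic data and compares the result with the exact posterior.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
N, D, sigma2, delta, beta = 50, 3, 0.25, 1.0, 0.5
X = rng.normal(size=(N, D))
y = X @ np.array([1.0, -2.0, 0.5]) + np.sqrt(sigma2) * rng.normal(size=N)

mu, prec = np.zeros(D), np.eye(D)                  # initial q = N(0, I)
for t in range(100):
    grad_mu_F = X.T @ (y - X @ mu) / sigma2        # nabla_mu F_t at mu_t
    grad_Sigma_F = -X.T @ X / (2 * sigma2)         # nabla_Sigma F_t (constant for this model)
    prec = (1 - beta) * prec + beta * (delta * np.eye(D) - 2 * grad_Sigma_F)  # precision update
    mu = mu + beta * np.linalg.solve(prec, grad_mu_F - delta * mu)            # mean update uses Sigma_{t+1}

post_prec = delta * np.eye(D) + X.T @ X / sigma2   # exact posterior precision
post_mean = np.linalg.solve(post_prec, X.T @ y / sigma2)
```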
<h3 id="variational-online-newton-algorithm-von">Variational Online-Newton algorithm (VON)</h3>
<p>We are now very close to a complete NGVI algorithm.
We just need to deal with the $\nabla_\vmu {\color{purple}\mathcal{F}_t}$ and $\nabla_\vSigma {\color{purple}\mathcal{F}_t}$ terms.
Fortunately, we can express these in terms of the gradient and Hessian of the negative log-likelihood by invoking Bonnet’s and Price’s theorems (<a href="http://www0.cs.ucl.ac.uk/staff/c.archambeau/publ/neco_mo09_web.pdf">Opper & Archambeau, 2009</a>; <a href="https://arxiv.org/pdf/1401.4082.pdf">Rezende et al., 2014</a>):</p>
<p>\begin{align}
\label{eq:bonnet_gradient}
\nabla_\vmu {\color{purple}\mathcal{F}_t} &= \nabla_\vmu \myexpect_{q_{\veta_t}(\vparam)} \left[\log {\color{purple}p(\data\pipe\vparam)}\right]& &= \myexpect_{q_{\veta_t}(\vparam)} \left[\nabla_\vparam \log {\color{purple}p(\data\pipe\vparam)}\right]& &= -\myexpect_{q_{\veta_t}(\vparam)} \left[N{\color{purple}\vg(\vparam)} \right], \newline
\label{eq:bonnet_hessian}
\nabla_\vSigma {\color{purple}\mathcal{F}_t} &= \nabla_\vSigma \myexpect_{q_{\veta_t}(\vparam)} \left[\log {\color{purple}p(\data\pipe\vparam)}\right]& &= \frac{1}{2}\myexpect_{q_{\veta_t}(\vparam)} \left[\nabla^2_{\vparam\vparam} \log {\color{purple}p(\data\pipe\vparam)}\right]& &= -\frac{1}{2}\myexpect_{q_{\veta_t}(\vparam)} \left[N{\color{purple}\vH(\vparam)} \right],
\end{align}</p>
<p>where we have used the average per-example gradient ${\color{purple}\vg(\vparam)}$ and Hessian ${\color{purple}\vH(\vparam)}$ of the negative log-likelihood (the dataset has $N$ examples).</p>
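<p>Bonnet’s and Price’s theorems are easy to sanity-check by Monte Carlo in one dimension. The sketch uses a stand-in function $f(\theta) = \theta^4$ (in place of $\log {\color{purple}p(\data\pipe\vparam)}$) under $q = \gauss(\mu, s^2)$, for which $\myexpect_q[\theta^4] = \mu^4 + 6\mu^2 s^2 + 3s^4$ gives the exact derivatives with respect to $\mu$ and $s^2$. The sample averages of $f'(\theta)$ and $\frac{1}{2}f''(\theta)$ should match them up to Monte-Carlo error.</p>

```python
import numpy as np

rng = np.random.default_rng(0)
mu, s2 = 0.7, 0.3
theta = mu + np.sqrt(s2) * rng.standard_normal(1_000_000)   # samples from q = N(mu, s2)

def f_grad(t):
    """f'(theta) for the stand-in f(theta) = theta^4."""
    return 4 * t**3

def f_hess(t):
    """f''(theta) for the stand-in f(theta) = theta^4."""
    return 12 * t**2

# Exact derivatives of E_q[theta^4] = mu^4 + 6 mu^2 s2 + 3 s2^2.
dmu_exact = 4 * mu**3 + 12 * mu * s2
ds2_exact = 6 * mu**2 + 6 * s2

print(f_grad(theta).mean(), dmu_exact)          # Bonnet: d/dmu E[f] = E[f']
print(0.5 * f_hess(theta).mean(), ds2_exact)    # Price: d/ds2 E[f] = E[f''] / 2
```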
<p>One final step. Until now we have been exact in our derivations (given a VI objective and Gaussian approximating family).
But we now need to make our first approximation to estimate Equations \eqref{eq:bonnet_gradient} and \eqref{eq:bonnet_hessian}: we use a Monte-Carlo sample $\vparam_t \sim q_{\veta_t}(\vparam) = \gauss(\vparam; \vmu_t, \vSigma_t)$ to approximate the expectation terms.
We expect any approximation error to reduce as we increase the number of Monte-Carlo samples.</p>
<p>This leads to an algorithm called Variational Online-Newton (VON) in <a href="https://arxiv.org/pdf/1806.04854.pdf">Khan et al. (2018)</a>,</p>
<p>\begin{align}
\label{eq:VON_Sigma}
\hspace{1em}\vSigma_{t+1}^{-1} &= (1-\beta_t)\vSigma_t^{-1} + \beta_t (N {\color{purple}\vH(\vparam_t)} + {\color{blue}\delta\vI}) \newline
\label{eq:VON_mu}
\vmu_{t+1} &= \vmu_t - \beta_t \vSigma_{t+1} (N{\color{purple}\vg(\vparam_t)} + {\color{blue}\delta}\vmu_t).
\end{align}</p>
<p>We can run this algorithm on models where we can calculate the gradient and Hessian (such as by using automatic differentiation). But calculating Hessians of (non-toy) neural networks is still difficult. We therefore have to approximate the Hessian ${\color{purple}\vH(\vparam_t)}$ in some way. This is done in the algorithm VOGN.</p>
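<p>For a model small enough that the Hessian is cheap, VON can be run directly. The sketch below applies Equations \eqref{eq:VON_Sigma} and \eqref{eq:VON_mu} to a hypothetical two-dimensional Bayesian logistic regression on synthetic data, with one Monte-Carlo sample per iteration; the function names, learning rate and number of steps are illustrative choices, not values taken from Khan et al. (2018).</p>

```python
import numpy as np

def nll_grad_hess(theta, X, y):
    """Average per-example gradient g and Hessian H of the logistic-regression NLL."""
    p = 1 / (1 + np.exp(-X @ theta))                   # predicted probabilities
    g = X.T @ (p - y) / len(y)
    H = (X * (p * (1 - p))[:, None]).T @ X / len(y)
    return g, H

def von(X, y, delta=1.0, beta=0.1, steps=2000, seed=0):
    """Variational Online-Newton with one Monte-Carlo sample per iteration."""
    rng = np.random.default_rng(seed)
    N, D = X.shape
    mu, prec = np.zeros(D), delta * np.eye(D)          # start from the prior
    for _ in range(steps):
        L = np.linalg.cholesky(np.linalg.inv(prec))    # Sigma_t = L L^T
        theta = mu + L @ rng.standard_normal(D)        # theta_t ~ N(mu_t, Sigma_t)
        g, H = nll_grad_hess(theta, X, y)
        prec = (1 - beta) * prec + beta * (N * H + delta * np.eye(D))  # Sigma_{t+1}^{-1}
        mu = mu - beta * np.linalg.solve(prec, N * g + delta * mu)     # uses Sigma_{t+1}
    return mu, np.linalg.inv(prec)

# Synthetic data from hypothetical "true" weights.
rng = np.random.default_rng(1)
X = rng.normal(size=(200, 2))
y = (rng.uniform(size=200) < 1 / (1 + np.exp(-X @ np.array([1.5, -1.0])))).astype(float)
mu, Sigma = von(X, y)
```

With a single sample per step, the mean keeps jittering around its fixed point; averaging over more samples or decaying $\beta_t$ would reduce this noise.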
<h3 id="variational-online-gauss-newton-vogn">Variational Online Gauss-Newton (VOGN)</h3>
<p>The Gauss-Newton matrix (<a href="https://arxiv.org/pdf/1412.1193.pdf">Martens, 2014</a>; <a href="https://www.cs.toronto.edu/~graves/nips_2011.pdf">Graves, 2011</a>; <a href="https://nic.schraudolph.org/pubs/Schraudolph02.pdf">Schraudolph, 2002</a>) approximates the Hessian using only first-order information, ${\color{purple}\vH(\vparam_t)} = -\frac{1}{N}\nabla^2_{\vparam\vparam} \log {\color{purple}p(\data \pipe \vparam_t)} \approx \frac{1}{N} \sum_{i\in\data} {\color{purple}\vg_i(\vparam_t) \vg_i(\vparam_t)^\top}$, where ${\color{purple}\vg_i(\vparam_t)}$ is the gradient of the negative log-likelihood of the $i$-th example.
It has some nice properties, such as being positive semi-definite (which we require so that the precision $\vSigma_{t+1}^{-1}$ in Equation \eqref{eq:VON_Sigma} stays positive definite), making it a suitable choice.
We expect it to become a better approximation to the Hessian as we train for longer.
As we will see in the second part of this blog, it can also be calculated relatively quickly.
Please see the references for more details on the benefits and disadvantages of using the Gauss-Newton matrix, such as how it is connected to the (empirical) Fisher Information Matrix.</p>
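<p>The sketch below compares the exact average Hessian with $\frac{1}{N}\sum_i {\color{purple}\vg_i\vg_i^\top}$ for a hypothetical logistic-regression model on synthetic data. At the parameters that generated the labels, the two agree up to sampling error (the information-matrix equality), whereas at a badly fitting parameter, here the sign-flipped true weights, the outer-product approximation differs noticeably. This gives one concrete sense in which the approximation can improve as training approaches a good fit.</p>

```python
import numpy as np

def hessian_and_outer(theta, X, y):
    """Exact average NLL Hessian and (1/N) sum_i g_i g_i^T for logistic regression."""
    p = 1 / (1 + np.exp(-X @ theta))
    G = X * (p - y)[:, None]                       # per-example gradients g_i, shape (N, D)
    H = (X * (p * (1 - p))[:, None]).T @ X / len(y)
    GN = G.T @ G / len(y)
    return H, GN

rng = np.random.default_rng(0)
w_true = np.array([1.5, -1.0])                     # hypothetical data-generating weights
X = rng.normal(size=(100_000, 2))
y = (rng.uniform(size=len(X)) < 1 / (1 + np.exp(-X @ w_true))).astype(float)

H, GN = hessian_and_outer(w_true, X, y)            # nearly equal
H_far, GN_far = hessian_and_outer(-w_true, X, y)   # noticeably different
print(np.linalg.norm(GN - H) / np.linalg.norm(H))
print(np.linalg.norm(GN_far - H_far) / np.linalg.norm(H_far))
```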
<p>This Gauss-Newton approximation is the key approximation to go from VON to VOGN.
But we also make some other approximations to allow for good scaling to large datasets/architectures.
Here is a full list of changes to go from VON to VOGN with some comments:</p>
<ol>
<li>Use a stochastic minibatch $\mathcal{M}_t$ of size $M$ instead of all the data at every iteration. The per-example gradients in this minibatch are ${\color{purple}\vg_i(\vparam_t)}$ and the average gradient is ${\color{purple}\hat{\vg}(\vparam_t)} = \frac{1}{M}\sum_{i\in\mathcal{M}_t} {\color{purple}\vg_i(\vparam_t)}$.
<ul>
<li>Using stochastic minibatches is crucial to scale algorithms to large datasets, and of course is common practice.</li>
</ul>
</li>
<li>Re-parameterise the update equations, $\mathbf{S}_t = (\vSigma_t^{-1} - {\color{blue}\delta\vI}) / N$.
<ul>
<li>This makes the equations simpler.</li>
</ul>
</li>
<li>Use a mean-field approximating family instead of a full-covariance Gaussian: $\mathbf{S}_t = \text{diag} (\vs_t)$.
<ul>
<li>This drastically reduces the number of parameters and is a common approximation employed in variational Bayesian neural networks.
But we do not have to stick to this. SLANG (<a href="https://arxiv.org/pdf/1811.04504.pdf">Mishkin et al., 2018</a>) uses a low-rank + diagonal covariance structure.
In the second part of this blog, we will see a K-FAC approximation (<a href="https://arxiv.org/pdf/1712.02390.pdf">Zhang et al., 2018</a>).</li>
</ul>
</li>
<li>Use the Gauss-Newton approximation to the (diagonal) Hessian: ${\color{purple}\vH(\vparam_t)} \approx \frac{1}{M} \sum_{i\in\mathcal{M}_t} \left( {\color{purple}\vg_i(\vparam_t)}^2 \right)$.
<ul>
<li>We have calculated this on a minibatch of data, and simplified the calculation to be element-wise squaring as we are using a diagonal approximation.</li>
</ul>
</li>
<li>Use separate learning rates ${\alpha_t, \beta_t}$ in the update equations for ${\vmu_t, \vs_t}$.
<ul>
<li>Strictly, the two learning rates should be the same. But the learning rates do not affect the fixed points of the algorithm (although they may affect which local minimum the algorithm converges to!). By introducing another hyperparameter, we hope for quicker convergence. As we shall see in the next blog post, this additional learning rate is usually not difficult to tune.</li>
</ul>
</li>
</ol>
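<p>Item 4 is easy to get subtly wrong in code, so here is a minimal NumPy sketch of the diagonal approximation on one minibatch. It assumes a toy linear-regression loss $\frac{1}{2}(\mathbf{x}_i^\top \vparam - y_i)^2$, for which per-example gradients have a closed form; in a neural network they would come from automatic differentiation. The function names are our own, hypothetical helpers, not from any library.</p>

```python
import numpy as np

def per_example_grads(theta, X, y):
    """Per-example gradients of the toy loss 0.5 * (x_i @ theta - y_i)**2.

    Row i of the result is g_i(theta); a real network would get these from autodiff.
    """
    residuals = X @ theta - y              # shape (M,)
    return residuals[:, None] * X          # shape (M, D)

def diag_gauss_newton(theta, X, y):
    """Item 4's diagonal approximation: (1/M) * sum_i g_i(theta)**2, element-wise."""
    G = per_example_grads(theta, X, y)
    return np.mean(G**2, axis=0)

rng = np.random.default_rng(0)
X = rng.normal(size=(32, 3))               # one minibatch: M = 32 examples, D = 3 parameters
y = X @ np.array([1.0, -2.0, 0.5]) + rng.normal(size=32)
theta = np.zeros(3)

h = diag_gauss_newton(theta, X, y)         # average of squared per-example gradients
g_hat = per_example_grads(theta, X, y).mean(axis=0)
print(h)
print(g_hat**2)                            # square of the average gradient: a different quantity
```

<p>By Jensen’s inequality, the average of the squared per-example gradients is never smaller than the square of the averaged gradient, and the two can differ substantially; this is why the formula in item 4 is written in terms of the individual ${\color{purple}\vg_i(\vparam_t)}$ rather than only ${\color{purple}\hat{\vg}(\vparam_t)}$.</p>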
<p>These changes lead to our final VOGN algorithm, ready for running on neural networks,</p>
<p>\begin{align}
\label{eq:VOGN_mu}
\vmu_{t+1} &= \vmu_t - \alpha_t \frac{ {\color{purple}\hat{\vg}(\vparam_t)} + {\color{blue}\tilde{\delta}}\vmu_t}{\vs_{t+1} + {\color{blue}\tilde{\delta}}}, \newline
\label{eq:VOGN_Sigma}
\vs_{t+1} &= (1-\beta_t)\vs_t + \beta_t \frac{1}{M} \sum_{i\in\mathcal{M}_t}\left( {\color{purple}\vg_i(\vparam_t)}^2 \right),
\end{align}</p>
<p>where ${\color{blue}\tilde{\delta}} = {\color{blue}\delta}/N$, and all operations are element-wise.</p>
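<p>To make the two updates concrete, here is a minimal NumPy sketch of mean-field VOGN on a toy linear-regression problem. It is our own illustration, not the implementation from the paper or its codebase. Assumptions not stated above: a Gaussian likelihood with unit noise variance (per-example loss $\frac{1}{2}(\mathbf{x}_i^\top \vparam - y_i)^2$), a single sample $\vparam_t \sim \mathcal{N}(\vmu_t, \vSigma_t)$ per iteration with diagonal precision $\vSigma_t^{-1} = N(\vs_t + \tilde{\delta})$ (which follows from the re-parameterisation in item 2), constant learning rates, and $\vs$ initialised to ones.</p>

```python
import numpy as np

def vogn(X, y, delta=1.0, alpha=0.1, beta=0.1, batch_size=32, steps=2000, seed=0):
    """Toy mean-field VOGN for linear regression, following the two updates above.

    Assumptions for illustration only: per-example loss 0.5 * (x_i @ theta - y_i)**2
    (a Gaussian likelihood with unit noise variance), one sample theta_t ~ q_t per step,
    constant learning rates, and s initialised to ones.
    """
    rng = np.random.default_rng(seed)
    N, D = X.shape
    delta_tilde = delta / N
    mu, s = np.zeros(D), np.ones(D)
    for _ in range(steps):
        idx = rng.choice(N, size=batch_size, replace=False)   # minibatch M_t
        Xb, yb = X[idx], y[idx]
        # Sample theta_t from q_t, whose precision is Sigma_t^{-1} = N * (s_t + delta_tilde).
        theta = mu + rng.normal(size=D) / np.sqrt(N * (s + delta_tilde))
        G = (Xb @ theta - yb)[:, None] * Xb      # per-example gradients g_i(theta_t)
        g_hat = G.mean(axis=0)                    # minibatch average gradient
        s = (1 - beta) * s + beta * np.mean(G**2, axis=0)                 # s_{t+1}
        mu = mu - alpha * (g_hat + delta_tilde * mu) / (s + delta_tilde)  # uses s_{t+1}
    return mu, 1.0 / np.sqrt(N * (s + delta_tilde))   # posterior means and std devs

rng = np.random.default_rng(1)
X = rng.normal(size=(1000, 3))
true_w = np.array([1.0, -2.0, 0.5])
y = X @ true_w + rng.normal(size=1000)   # unit noise, matching the assumed likelihood
mu, sigma = vogn(X, y)
print(mu)      # close to true_w
print(sigma)   # roughly 1 / sqrt(N), i.e. about 0.03 here
```

<p>Note that $\vs_{t+1}$ is computed before $\vmu_{t+1}$, because the mean update divides by it. In this well-specified toy problem the squared per-example gradients approximate the Hessian near the solution, so the returned standard deviations come out close to $1/\sqrt{N}$, roughly what exact Bayesian linear regression would give here.</p>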
<p>So how does this perform in practice? We explore this in detail in the second part of the blog.
For now, we borrow Figure 1(b) from <a href="https://arxiv.org/pdf/1807.04489.pdf">Khan & Nielsen (2018)</a>, which shows how Natural-Gradient VI (VOGN) can converge much quicker than Gradient VI (implemented as Bayes-By-Backprop (<a href="https://arxiv.org/pdf/1505.05424.pdf">Blundell et al., 2015</a>)) on two relatively small datasets.</p>
<div class="image-container">
<img src="/blog/assets/images/ngvi-bnns/comparison.png" alt="VOGN can converge quickly." id="figure-VOGNvsBBB" style="width: 100%; max-width: 700px" />
<p class="caption">
Figure 1: VOGN can converge quickly.
</p>
</div>
<h2 id="were-done">We’re done!</h2>
<p>I hope you now understand NGVI for BNNs a little better. You have seen how the equations are derived, and hopefully have more of a feel for why and when they might work. There was a lot of detailed maths, but I have tried to provide some intuition and make all our approximations clear.</p>
<p>In this first part, we stopped at VOGN on small neural networks. In the <a href="https://mlg-blog.com/2021/11/24/ngvi-bnns-part-2.html">second part</a>, we will compare VOGN with stochastic gradient-based algorithms such as SGD and Adam to provide some further intuition. We will take some inspiration from Adam to scale VOGN to much larger datasets/architectures, such as ImageNet/ResNets. The next blog post will be a lot less mathematical!</p>
<p>If you would like to cite this blog post, you can use the following bibtex:</p>
<blockquote>
<p>@misc{swaroop_2021, title={Natural-Gradient Variational Inference}, url={https://mlg-blog.com/2021/04/13/ngvi-bnns-part-1.html}, author={Swaroop, Siddharth}, year={2021}}</p>
</blockquote>Siddharth SwaroopWhat does it mean to combine variational inference with natural gradients? Can this scale to neural networks? What kind of approximations do we need to make? We take a detailed look at the mathematical derivations of such algorithms.What Keeps a Bayesian Awake At Night? Part 1: Day Time2021-03-31T00:00:00+00:002021-03-31T00:00:00+00:00https://mlg.eng.cam.ac.uk/blog/2021/03/31/what-keeps-a-bayesian-awake-at-night-part-1<blockquote>
<p><em>The theory of probabilities is at bottom nothing but common sense reduced to calculus;</em>
<em>it enables us to appreciate with exactness that which accurate minds feel with a sort of instinct for which ofttimes they are unable to account.</em></p>
<p>— <em>Pierre-Simon Laplace (1749–1827)</em></p>
</blockquote>
<p>The Cambridge Machine Learning Group is renowned for having drunk the Bayesian Kool-Aid. We evangelise about probabilistic approaches in our teaching and research as a principled and unifying view of machine learning and statistics. In this context, I (Rich) have found it striking that many of the most influential contributors to this brand of machine learning and statistics are more circumspect. Examples range from Michael Jordan remarking that “this place is far too Bayesian”,<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup> to Geoff Hinton listing the key properties of the problems he’s interested in and concluding that Bayesian approaches just don’t cut it. Even Zoubin Ghahramani confides that aspects of the Bayesian approach “keep him awake at night”.<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup></p>
<p>So, as we kick off this new blog, we thought we’d dig into these concerns and attempt to burst the Bayesian bubble. To do this, we’ve written two posts. In this first part, we present some of the standard sunny arguments for probabilistic inference and decision making. In the second part, we’ll shoot some of these arguments down and face the demons lurking in the night.</p>
<h2 id="what-is-the-probabilistic-approach-to-inference-and-decision-making">What is the probabilistic approach to inference and decision making?</h2>
<p>The goal of an inference problem is to estimate an unknown quantity $X$ from known quantities $D$. Examples include inferring the mass of the Higgs boson ($X$) from collider data ($D$); estimating the prevalence of Covid-19 infections ($X$) from PCR test data ($D$); or reconstructing files ($X$) from corrupted versions stored on a damaged hard disk ($D$).</p>
<p>The probabilistic approach to solving such problems proceeds in three stages:</p>
<p><strong>Stage 1: probabilistic modelling.</strong> The first stage is called <em>probabilistic modelling</em> and involves designing a probabilistic recipe that describes how all the variables, known variables $D$ and unknown variables $X$, are assumed to be produced. The model specifies a joint distribution $p(X, D)$ over all these variables, and samples from it should reflect typical settings of the variables that you might have expected to encounter before seeing any data.</p>
<p><strong>Stage 2: probabilistic inference.</strong> The second stage is called <em>probabilistic inference</em>. In this stage, the sum and product rules of probability are used to manipulate the joint distribution over all variables into the conditional distribution of the unknown variables given the known variables:</p>
<p>\begin{equation} \label{eq:posterior}
p( X \cond D ) = \frac{p( X, D)}{p( D )}.
\end{equation}</p>
<p>The distribution on the left-hand side of this equation is known as the <em>posterior distribution</em>. It tells us how probable any setting of the unknown variables $X$ is given the known variables $D$. In this way, it tells us not only the most likely setting of the unknown variables ($\hat X_{\text{MAP}} = \argmax_{X} p(X \cond D)$), but also our uncertainty about $X$: it summarises our <em>belief</em> about $X$ after seeing $D$. Equation \eqref{eq:posterior} follows from the product rule of probability;
another application of the product rule, $p(X, D) = p(D \cond X)\,p(X)$, leads to Bayes’ rule.</p>
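<p>As a tiny worked example of Stage 2 (with made-up numbers), the sketch below takes a discrete joint distribution $p(X, D)$, conditions on an observed value of $D$ using the sum and product rules, and checks that Bayes’ rule gives the same posterior.</p>

```python
import numpy as np

# Hypothetical joint p(X, D): rows are the three values of the unknown X,
# columns are the two possible values of the observed D. All entries sum to one.
joint = np.array([[0.30, 0.10],
                  [0.15, 0.15],
                  [0.05, 0.25]])

d = 1                                   # the value of D we actually observed
evidence = joint[:, d].sum()            # sum rule: p(D = d) = sum over X of p(X, D = d)
post = joint[:, d] / evidence           # product rule rearranged: p(X | D = d)

# Bayes' rule gives the same answer via the prior p(X) and likelihood p(D | X).
prior = joint.sum(axis=1)               # p(X), by the sum rule
likelihood = joint / prior[:, None]     # p(D | X), by the product rule
post_bayes = likelihood[:, d] * prior / (likelihood[:, d] * prior).sum()

print(post)                # about [0.2, 0.3, 0.5]
print(np.argmax(post))     # MAP estimate of X: index 2
```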
<p>In the example of inferring the mass of the Higgs boson, the unknown $X$ is a parameter.<sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">3</a></sup> Inferring parameters using the sum and product rules to form the posterior distribution is called being <em>Bayesian</em>.</p>
<p><strong>Stage 3: Bayesian decision theory.</strong> In real-world problems, inferences are usually made to serve decision making. For example, inferences could inform the design of a new particle accelerator to pin down particle masses; could decide whether to implement another national lockdown; or could decide whether to prosecute someone for financial crimes based on data recovered from a hard drive.</p>
<p>The probabilistic approach supports decision making in a third stage which goes by the grand name of <em>Bayesian decision theory</em>. Here the user provides a loss function $L(X, \delta)$ which specifies how unhappy they would be if they take decision $\delta$ when the unknown variables take a value $X$. We can compute how unhappy they will be on average with any decision by averaging the loss function over all possible settings of the unobserved variables, weighted by how probable they are under the posterior distribution $p( X \cond D )$:</p>
<p>\begin{equation}
(\text{average unhappiness})(\delta) = \E_{p(X \cond D)}[L(X, \delta)].
\end{equation}</p>
<p>This quantity is called the <em>posterior expected loss</em><sup id="fnref:7" role="doc-noteref"><a href="#fn:7" class="footnote" rel="footnote">4</a></sup>. We now pick the decision that we expect to be least unhappy about by minimising our expected unhappiness with respect to our decision $\delta$, and we’re done!</p>
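<p>Continuing with a made-up posterior over three values of $X$, the sketch below evaluates the posterior expected loss of each candidate decision and picks the smallest. The squared-error loss is our illustrative choice; any loss matrix <code>loss[x, delta]</code> would do.</p>

```python
import numpy as np

# Suppose (hypothetically) the posterior over three values X = 0, 1, 2 is:
post = np.array([0.2, 0.3, 0.5])
x_values = np.array([0, 1, 2])
decisions = np.array([0, 1, 2])

# Loss L(X, delta): here squared error, an illustrative choice.
loss = (x_values[:, None] - decisions[None, :]) ** 2    # loss[x, delta]

expected_loss = post @ loss                 # posterior expected loss for each decision
best = decisions[np.argmin(expected_loss)]  # decision with the least expected unhappiness

print(expected_loss)   # about [2.3, 0.7, 1.1]
print(best)            # 1, even though the MAP value of X is 2
```

<p>With squared error, $\E[(X - \delta)^2] = \operatorname{Var}[X] + (\E[X] - \delta)^2$, so the best allowed decision is the one nearest the posterior mean (here $1.3$), which need not be the MAP value.</p>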
<p><strong>Summary.</strong> This framework proposes a cleanly separated sequential three-step procedure: first, articulate your assumptions about the data via a probabilistic model; second, compute the posterior distribution over unknown variables; and third, select the decision which minimises the average loss under the posterior.</p>
<h2 id="whats-the-formal-justification-for-the-probabilistic-approach-to-inference-and-decision-making">What’s the formal justification for the probabilistic approach to inference and decision making?</h2>
<p>Why does a Bayesian represent their beliefs with probabilities<sup id="fnref:4" role="doc-noteref"><a href="#fn:4" class="footnote" rel="footnote">5</a></sup>, reason according to the sum and product rules, and select actions which minimise the posterior expected loss? We’ll now review the most common theoretical arguments.</p>
<p><strong>(1) de Finetti’s exchangeability theorem</strong> justifies the use of model parameters $\theta$, conditional distributions $p(D_n \cond \theta)$ over data given parameters (also called the <em>likelihood of parameters</em>), and critically <em>prior distributions over parameters</em> $p(\theta)$ when specifying probabilistic models. The theorem says that if you believe the order in which the data $D = (D_n)_{n=1}^N$ arrives is unimportant — $p(D)$ is invariant to the order of $(D_n)_{n=1}^N$, an idea called <em>exchangeability</em> — then there exists a random variable $\theta$ with associated prior distribution such that the data $D$ are i.i.d. given $\theta$ and your belief is recovered by marginalising over $\theta$:</p>
<p>\begin{equation}
p(D) = \int p(\theta) \prod_{n=1}^N p(D_n \cond \theta) \,\mathrm{d}\theta.
\end{equation}</p>
<p>The argument presupposes the use of probability distributions over data, but shows that this in combination with exchangeability entails the existence of parameters with associated prior and posterior distributions. This idea was important historically as there were schools of statistical thought that eschewed placing distributions over parameters, but were happy placing distributions over data.</p>
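<p>A standard concrete case of the theorem (our example; the post does not use it) is a coin with unknown bias $\theta$ and a Beta prior. The sketch computes $p(D)$ in two ways: by the chain rule, one flip at a time, for every ordering of the same flips; and by the mixture integral above, which has a closed form for a Beta prior. Every ordering gives the same value, as exchangeability requires.</p>

```python
import itertools
import math

def log_beta(u, v):
    """Log of the Beta function B(u, v)."""
    return math.lgamma(u) + math.lgamma(v) - math.lgamma(u + v)

def marginal_sequential(data, a=1.0, b=1.0):
    """p(D) by the chain rule, using the Beta(a, b)-Bernoulli predictive at each step."""
    prob, ones = 1.0, 0
    for n, x in enumerate(data):
        p_one = (a + ones) / (a + b + n)       # p(D_n = 1 | D_1, ..., D_{n-1})
        prob *= p_one if x == 1 else 1.0 - p_one
        ones += x
    return prob

def marginal_mixture(data, a=1.0, b=1.0):
    """p(D) as the integral of p(theta) * prod_n p(D_n | theta) with a Beta(a, b) prior."""
    k, N = sum(data), len(data)
    return math.exp(log_beta(a + k, b + N - k) - log_beta(a, b))

data = [1, 1, 0, 1, 0]
orderings = set(itertools.permutations(data))
values = [marginal_sequential(list(p)) for p in orderings]
print(min(values), max(values))   # equal up to round-off: the order does not matter
print(marginal_mixture(data))     # the same number, 1/60 for this data
```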
<p><strong>(2) Cox’s theorem and coherence.</strong> Cox’s theorem (<a href="https://aapt.scitation.org/doi/abs/10.1119/1.1990764">Cox, 1945</a>; <a href="https://bayes.wustl.edu/etj/prob/book.pdf">Jaynes, 2003</a>) justifies the use of a probabilistic model and application of the sum and the product rules to perform inference. The theorem starts out by listing a number of <em>desiderata</em> that any reasonable system of quantitative rules for inference should satisfy. One very important such desideratum is <em>consistency</em> or <em>coherence</em><sup id="fnref:5" role="doc-noteref"><a href="#fn:5" class="footnote" rel="footnote">6</a></sup>: if there are multiple ways of arriving at an inference, they should all give the same answer. For example, updating our beliefs about an unknown variable $X$ after observing data $D=(D_n)_{n=1}^N$ should give the same result as incrementally updating our beliefs about $X$ one data point $D_n$ at a time. Cox’s theorem concludes that every system satisfying the desiderata must be probability theory. In particular, it identifies probability theory as the unique extension of propositional logic, where a proposition is either true or false, to varying degrees of plausibility.</p>
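<p>The coherence requirement can be checked numerically. In this sketch (a coin-flip model with a prior on a grid of bias values, chosen for illustration), updating on all the data at once and updating one data point at a time, in either order, give the same posterior up to floating-point round-off.</p>

```python
import numpy as np

theta = np.linspace(0.005, 0.995, 199)        # grid over a coin's bias (illustrative model)
prior = np.full(theta.size, 1.0 / theta.size)  # uniform prior on the grid
data = [1, 0, 1, 1, 0, 1, 1]

def likelihood(x, theta):
    """p(x | theta) for a single coin flip x in {0, 1}."""
    return theta if x == 1 else 1.0 - theta

def batch_posterior(data):
    """Multiply all likelihoods into the prior, then normalise once."""
    post = prior * np.prod([likelihood(x, theta) for x in data], axis=0)
    return post / post.sum()

def sequential_posterior(data):
    """Update one flip at a time: each posterior becomes the next prior."""
    post = prior.copy()
    for x in data:
        post = post * likelihood(x, theta)
        post = post / post.sum()
    return post

print(np.max(np.abs(batch_posterior(data) - sequential_posterior(data))))        # round-off only
print(np.max(np.abs(batch_posterior(data) - sequential_posterior(data[::-1]))))  # order is irrelevant too
```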
<p><strong>(3) The Dutch book argument</strong> (<a href="https://EconPapers.repec.org/RePEc:hay:hetcha:ramsey1926">Ramsey, 1926</a>; <a href="http://eudml.org/doc/212523">de Finetti, 1931</a>) is another argument for the optimality of modelling and inference, one which connects to decision making. The argument goes as follows: if you’re willing to take bets on propositions with certain odds (in a way, these odds represent your beliefs), then, unless these odds are consistent with probability theory, you’re willing to take a collection of bets that nets a sure loss (a <em>Dutch book</em>).</p>
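<p>A two-ticket example makes the sure loss explicit. The numbers are hypothetical: the agent prices a ticket that pays 1 if $A$ is true at 0.6, and a ticket that pays 1 if $A$ is false also at 0.6. A bookie sells the agent both tickets; whichever way $A$ turns out, exactly one ticket pays, so the agent loses 0.2 for sure. (Had the prices summed to less than 1, the bookie would instead buy both tickets from the agent.)</p>

```python
def net_gain(a_is_true, price_a, price_not_a):
    """Agent's net gain after buying, at its own prices, one ticket paying 1 if A is true
    and one ticket paying 1 if A is false."""
    payout = (1.0 if a_is_true else 0.0) + (0.0 if a_is_true else 1.0)
    cost = price_a + price_not_a
    return payout - cost

# Incoherent beliefs (hypothetical numbers): P(A) + P(not A) = 1.2 instead of 1.
for a_is_true in (True, False):
    print("A true" if a_is_true else "A false", net_gain(a_is_true, 0.6, 0.6))  # about -0.2 either way

# Coherent beliefs: P(A) + P(not A) = 1, and this particular book no longer forces a loss.
for a_is_true in (True, False):
    print("A true" if a_is_true else "A false", net_gain(a_is_true, 0.6, 0.4))  # zero either way
```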
<p><strong>(4) Savage’s theorem</strong> (<a href="https://doi.org/10.1002/nav.3800010316">Savage, 1954</a>) is used to argue that optimal decision making entails the three-step probabilistic approach. Like Cox’s theorem, it takes an axiomatic approach, but here the axioms relate to decisions rather than inferences. The axioms include the idea that a decision maker is characterised by the ability to rank all decisions in some order of preference. The theorem then lists properties that any reasonable ranking should have.<sup id="fnref:6" role="doc-noteref"><a href="#fn:6" class="footnote" rel="footnote">7</a></sup> Savage’s theorem says that these properties entail that the decision maker’s ranking is consistent with them acting according to Bayesian decision theory (<a href="http://www.econ2.jhu.edu/people/Karni/savageseu.pdf">Karni, 2005</a>).</p>
<p>The above arguments justify large parts of the probabilistic approach to inference and decision making. In contrast, the arguments we turn to next are only concerned with specific properties. One way to view them is as <em>unit tests</em> that the Bayesian framework passes.</p>
<p>Before we turn to them, you may have noticed that the formulation of the probabilistic approach and the justifications made so far do not include a notion of the <em>true model</em> or <em>true parameters</em>; rather, all that matters is your own personal beliefs about the world, how you update them as data arrive, and how decisions are made. However, it is perfectly reasonable to ask questions like “How does the posterior over parameters behave when the data were generated using some true underlying parameter value?” or “If I have the right model and apply probabilistic inference, will my predictions be good?” The next two results step out of the Bayesian framework to answer these questions:</p>
<p><strong>(5) Doob’s consistency theorem</strong> (<a href="https://www.emis.de/journals/JEHPS/juin2009/Locker.pdf">Doob, 1949</a>) shows that Bayesian inference is <em>consistent</em>: under mild conditions, if the data were sampled from $p(D \cond X)$ for some true value of $X$, then, for all true values outside a set of prior probability zero, the posterior over the unknown variables $p(X \cond D)$ concentrates on this true value as the user collects more and more data. This is a frequentist analysis of a Bayesian estimation procedure, and it shows that the two paradigms can live comfortably side by side: Bayesian methods provide estimation procedures; frequentist tools allow analysis of these procedures. This is one reason why Michael Jordan thought that any Bayesian-focussed research group worth its salt should be paying close attention to frequentist ideas.</p>
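<p>A small simulation illustrates the kind of behaviour Doob’s theorem describes, though it proves nothing: the theorem concerns almost every true value under the prior, while this sketch tries a single, hypothetical true bias of 0.3 for a coin with a uniform Beta(1, 1) prior. As more flips from one data stream are observed, the posterior mean approaches 0.3 and the posterior standard deviation shrinks.</p>

```python
import numpy as np

def posterior_summaries(true_theta=0.3, sizes=(10, 100, 1000, 10000), a=1.0, b=1.0, seed=0):
    """Beta(a, b)-Bernoulli posterior mean and sd after the first n flips of one data stream."""
    rng = np.random.default_rng(seed)
    flips = rng.random(max(sizes)) < true_theta      # data sampled with the "true" bias
    out = []
    for n in sizes:
        k = int(flips[:n].sum())
        a_n, b_n = a + k, b + n - k                   # conjugate posterior Beta(a_n, b_n)
        mean = a_n / (a_n + b_n)
        sd = np.sqrt(a_n * b_n / ((a_n + b_n) ** 2 * (a_n + b_n + 1)))
        out.append((n, mean, sd))
    return out

for n, mean, sd in posterior_summaries():
    print(f"n={n:>5}  posterior mean={mean:.3f}  posterior sd={sd:.4f}")
```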
<p><strong>(6) Optimality of Bayesian predictions.</strong> In many applications, such as those typically encountered in machine learning, predicting future data points is the central focus. What guarantees do we have on the quality of such predictions arising from the probabilistic approach? Well, if we use the Kullback–Leibler (KL) divergence to measure the discrepancy between the true density $p(X \cond \theta)$ over the unknown $X$ and any data-dependent estimate of this density $q(X \cond D)$, then, when averaged over potential settings of the parameters and associated observed data $p(\theta)p(D\cond\theta)$, the estimate $q^*$ that minimises the average divergence is the Bayesian posterior predictive (<a href="https://doi.org/10.1093/biomet/62.3.547">Aitchison, 1975</a>):</p>
<p>\begin{equation}
q^*(X \cond D) = \argmin_{q} \E_{p(\theta,D)}[ \operatorname{KL}( p(X \cond \theta) \| q(X \cond D)) ] = p(X \cond D).
\end{equation}</p>
<p>This argument tells us that the Bayesian predictions coming from the “right model” are KL-optimal on average. Interestingly, this result connects to recent ideas in meta-learning (<a href="https://arxiv.org/abs/1805.09921">Gordon <em>et al.</em>, 2019</a>).</p>
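<p>Aitchison’s result can be checked numerically in a toy setting of our choosing: a coin with a uniform prior on its bias $\theta$, $N = 5$ observed flips, and $X$ the next flip. The sketch approximates the prior by a fine grid and compares the average KL divergence of three prediction rules $q(X = 1 \cond D)$, each a function of the number of heads $k$. The posterior predictive, which here is close to $(k + 1)/(N + 2)$, attains the smallest value.</p>

```python
import numpy as np
from math import comb

N = 5
theta = (np.arange(2000) + 0.5) / 2000          # midpoint grid for a uniform prior on (0, 1)
prior = np.full(theta.size, 1.0 / theta.size)
ks = np.arange(N + 1)
# lik[k, j] = p(k heads in N flips | theta_j): the binomial likelihood.
lik = np.array([comb(N, k) * theta**k * (1 - theta) ** (N - k) for k in range(N + 1)])

def bernoulli_kl(p, q):
    """KL(Bern(p) || Bern(q)), element-wise."""
    return p * np.log(p / q) + (1 - p) * np.log((1 - p) / (1 - q))

def average_kl(q):
    """E_{p(theta) p(k | theta)}[KL(Bern(theta) || Bern(q[k]))] for a rule q: k -> (0, 1)."""
    kl = bernoulli_kl(theta[None, :], q[:, None])   # shape (N + 1, grid size)
    return float(np.sum(prior[None, :] * lik * kl))

# Posterior predictive p(next flip = 1 | k heads), computed on the same grid.
q_bayes = (lik * prior * theta).sum(axis=1) / (lik * prior).sum(axis=1)
q_smoothed = (ks + 0.5) / (N + 1)            # predictive of a different prior: not optimal here
q_nearly_mle = (ks + 0.01) / (N + 0.02)      # close to the maximum-likelihood plug-in k / N

for name, q in [("Bayes", q_bayes), ("add-1/2", q_smoothed), ("near-MLE", q_nearly_mle)]:
    print(name, average_kl(q))
```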
<p><strong>(7) Wald’s theorem (<a href="https://doi.org/10.1214/aoms/1177730030">Wald, 1949</a>)</strong> can be used to justify minimising an expected loss as a way of decision making. The theorem is concerned with <em>admissible</em> decision rules, which are rules that no other decision rule dominates: no other rule achieves an average loss (averaged over the data) that is at least as good for every value of the unknown $X$ and strictly better for <em>some</em> value. This condition is a very low bar: we’d hope that any reasonable decision rule would have this property. However, surprisingly, Wald’s theorem says that the <em>only</em> rules which are admissible are essentially those derived from minimising the expected loss under some distribution (<a href="https://doi.org/10.1214/aoms/1177730030">Wald, 1949</a>; <a href="https://doi.org/10.1007/b98854">Lehmann & Casella, 1998</a>).</p>
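<p>The definition of admissibility can be made concrete with risk functions. In this sketch (our illustration) we estimate a coin’s bias from $n = 10$ flips under squared loss, and compute each rule’s average loss $R(\theta) = \E_{k \cond \theta}[(\delta(k) - \theta)^2]$ on a grid of $\theta$ values. A rule that adds 0.1 to the maximum-likelihood estimate is dominated, hence inadmissible, while neither the maximum-likelihood estimate nor the posterior mean under a uniform prior dominates the other. This only checks the definition on a grid; it does not prove Wald’s theorem.</p>

```python
import numpy as np
from math import comb

n = 10                                   # number of coin flips observed
thetas = np.linspace(0.0, 1.0, 101)      # values of the unknown bias at which we evaluate risk
ks = np.arange(n + 1)
# pmf[j, k] = p(k heads | thetas[j]) under a Binomial(n, theta) model.
pmf = np.array([[comb(n, k) * t**k * (1 - t) ** (n - k) for k in range(n + 1)] for t in thetas])

def risk(estimates):
    """Frequentist risk R(theta) = E_{k | theta}[(estimate(k) - theta)^2] under squared loss."""
    return (pmf * (estimates[None, :] - thetas[:, None]) ** 2).sum(axis=1)

def dominates(r1, r2):
    """True if risk r1 is never worse than r2 and strictly better somewhere."""
    return bool(np.all(r1 <= r2) and np.any(r1 < r2))

r_mle = risk(ks / n)                  # maximum-likelihood estimate k / n
r_bayes = risk((ks + 1) / (n + 2))    # posterior mean under a uniform prior
r_shift = risk(ks / n + 0.1)          # a deliberately biased rule

print(dominates(r_mle, r_shift))                              # True: the shifted rule is inadmissible
print(dominates(r_mle, r_bayes), dominates(r_bayes, r_mle))   # False False
```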
<h2 id="conclusion">Conclusion</h2>
<p>It is striking that a number of arguments based on a diversity of desirable properties — including coherence, optimal betting strategies, specifying sensible preferences over actions, frequentist guarantees like consistency and optimal predictive accuracy — all suggest that the probabilistic approach to inference is a reasonable one. But in the next post we’ll ask whether, in the dead of night, everything is as rosy as it seems in daylight.</p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p>Actually he was referring to the Gatsby unit in 2004 — in many ways the mother of the Cambridge Machine Learning Group — and his comment was a fair one. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:2" role="doc-endnote">
<p>This was in the Approximate Inference Workshop at NeurIPS in 2017. <a href="#fnref:2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:3" role="doc-endnote">
<p>Parameters are distinguished from variables by asking what happens as we see more data: variables get more numerous, parameters do not. <a href="#fnref:3" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:7" role="doc-endnote">
<p>We previously incorrectly called the posterior expected loss the <em>Bayes risk</em>. Thanks to Corey Yanofsky for pointing out the mistake. <a href="#fnref:7" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:4" role="doc-endnote">
<p>That probabilities represent degrees of belief is only one interpretation of probability. For example, in <em>Probability, Statistics, and Truth</em> (<a href="https://store.doverpublications.com/0486242145.html">Von Mises, 1928</a>), von Mises argues that probability concerns limiting frequencies of repeating events. In this view and contrary to the Bayesian view, it is meaningless to talk about the probability of a one-off event: it is not possible to repeatedly sample that one-off event, which means that it doesn’t have a limiting frequency. <a href="#fnref:4" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:5" role="doc-endnote">
<p>Consistency also has another specific technical meaning in statistics, so we will use the term coherence in what follows. <a href="#fnref:5" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:6" role="doc-endnote">
<p>The axioms are reminiscent of those used in <a href="https://en.wikipedia.org/wiki/Arrow%27s_impossibility_theorem">Arrow’s impossibility theorem</a> concerning <em>fair</em> voting systems. <a href="#fnref:6" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Wessel BruinsmaThe theory of probabilities is at bottom nothing but common sense reduced to calculus; it enables us to appreciate with exactness that which accurate minds feel with a sort of instinct for which ofttimes they are unable to account. — Pierre-Simon Laplace (1749–1827)What Keeps a Bayesian Awake At Night? Part 2: Night Time2021-03-31T00:00:00+00:002021-03-31T00:00:00+00:00https://mlg.eng.cam.ac.uk/blog/2021/03/31/what-keeps-a-bayesian-awake-at-night-part-2<blockquote>
<p><em>The theory of subjective probability describes ideally consistent behaviour and ought not, therefore, be taken too literally.</em></p>
<p>— <em>Leonard Jimmie Savage (1917–1971)</em></p>
</blockquote>
<p>In the <a href="/2021/03/31/what-keeps-a-bayesian-awake-at-night-part-1.html">first post</a> in this series, we laid out the standard arguments that we and many others have used to support the edifice which is Bayesian inference. In this post, we identify the weaknesses in these arguments that cause us to lose sleep at night.</p>
<p>We’ve split these weaknesses into three types: first, weaknesses in the standard mathematical justifications for the probabilistic approach; second, weaknesses arising in practice from the modelling stage; and, third, weaknesses arising from realising the inference stage due to computational constraints.</p>
<h2 id="weakness-1-standard-justifications-have-problems">Weakness 1: Standard justifications have problems</h2>
<p>We have seen that probability theory and Bayesian decision theory are usually justified in one of several ways, but do these standard justifications <em>really</em> stand up to scrutiny? Let’s go through each of the seven arguments in turn.</p>
<p><strong>(1) de Finetti’s exchangeability theorem</strong> presupposes a probability distribution over the data and shows this naturally leads to distributions over parameters. However, it does not justify the use of probability in the first place.</p>
<p><strong>(2) Cox’s theorem</strong> does justify the probabilistic approach, but making the argument watertight turns out to be far more delicate than the textbooks, say of <a href="https://bayes.wustl.edu/etj/prob/book.pdf">Jaynes (2003)</a> or <a href="https://www.microsoft.com/en-us/research/uploads/prod/2006/01/Bishop-Pattern-Recognition-and-Machine-Learning-2006.pdf">Bishop (2006)</a>, would have you believe. Making the theorem mathematically rigorous requires additional technical assumptions that muddy the clarity of the argument. As <a href="https://doi.org/10.1017/CBO9780511526596">Paris (1994, p. 24)</a> puts it: “[W]hen an attempt is made to fill in all the details some of the attractiveness of the original is lost.”<sup id="fnref:1" role="doc-noteref"><a href="#fn:1" class="footnote" rel="footnote">1</a></sup> Moreover, there remains disagreement about the desirability of several of Cox’s theorem’s other assumptions.<sup id="fnref:2" role="doc-noteref"><a href="#fn:2" class="footnote" rel="footnote">2</a></sup> Perhaps the most controversial assumption is that plausibilities are represented by real numbers. As a consequence, for <em>every two possible propositions</em>, one of the two propositions conclusively has higher (or equal) belief. (This is called <em>universal comparability</em>.) But what if we are truly ignorant about two matters, or have not yet formed an opinion? In that case, is it reasonable to require that the plausibilities assigned to the matters are necessarily comparable?<sup id="fnref:3" role="doc-noteref"><a href="#fn:3" class="footnote" rel="footnote">3</a></sup></p>
<p><strong>(3) The Dutch book argument</strong> tells us that any coherent actor should use probability to express their beliefs, but the argument leaves the door open as to how the actor should update their beliefs in light of new evidence (<a href="https://doi.org/10.1086/288169">Hacking, 1967</a>). This is because the standard Dutch book setup is <em>static</em>: it does not involve a step where beliefs are updated on the basis of new information. Justifying Bayes’ rule for this purpose requires additional assumptions. <em>Dynamic</em> versions of the Dutch book argument attempt to fix this flaw, but again the force of the argument is diminished and open to criticism (<a href="https://doi.org/10.1086/289350">Skyrms, 1987</a>).<sup id="fnref:4" role="doc-noteref"><a href="#fn:4" class="footnote" rel="footnote">4</a></sup></p>
<p><strong>(4) Savage’s theorem</strong> guarantees a nearly unique loss function (unique up to an affine transform) and a unique probability, which together form the expected loss. The troubling aspect of Savage’s theorem is that the constructed loss function depends <em>only</em> on the outcome of a decision, and Savage’s axioms imply that outcomes of decisions have a value which is independent of the state of the world (<a href="http://www.econ2.jhu.edu/people/Karni/savageseu.pdf">Karni, 2005</a>). As <a href="https://doi.org/10.1287/moor.24.1.8">Wakker & Zank (1998)</a> remark, this disentanglement can be undesirable. They give the example of health insurance, where the value of money depends on sickness or health. If the loss function is allowed to depend on other aspects of the world, then the probability constructed by the theorem is no longer unique (<a href="http://www.econ2.jhu.edu/people/Karni/savageseu.pdf">Karni, 2005</a>). For this reason, <a href="http://www.econ2.jhu.edu/people/Karni/savageseu.pdf">Karni (2005)</a> argues that the probability constructed by the theorem is arbitrary and thus cannot realistically represent the decision maker’s beliefs.</p>
<p><strong>(5) Doob’s consistency theorem</strong>, the <strong>(6) optimality of Bayesian predictions</strong> and <strong>(7) Wald’s theorem</strong> are all only unit tests: how reassured should we really be that Bayesian inference passes them? It is not clear that the guarantees on the optimality of Bayesian predictions are the guarantees you care about. For example, typically we’re faced with a single dataset and care only about performance on it alone, rather than the average performance across many potential datasets. Similarly, the admissibility of a decision rule in Wald’s theorem is desirable but not sufficient: indeed, there are admissible estimators that are not reasonable.<sup id="fnref:5" role="doc-noteref"><a href="#fn:5" class="footnote" rel="footnote">5</a></sup> Doob’s consistency theorem also suffers from subtle theoretical issues (<a href="https://doi.org/10.1214/aos/1176349830">Diaconis & Freedman, 1986</a>), which we discuss below in the context of model mismatch.</p>
<p><strong>Take away.</strong> It is clear, then, that uniquely and precisely justifying the three-stage Bayesian approach via a single argument is much more delicate than many would have you believe. However, these issues don’t trouble our sleep. This is partly because it is reassuring that so many, and such diverse, arguments suggest that it is a sensible approach.<sup id="fnref:6" role="doc-noteref"><a href="#fn:6" class="footnote" rel="footnote">6</a></sup> It is also because there are far bigger issues to worry about.</p>
<h2 id="weakness-2-modelling-is-hard-and-inferences-are-sensitive-to-innocuous-details">Weakness 2: Modelling is hard and inferences are sensitive to innocuous details</h2>
<p>All practitioners of Bayesian inference will know well that beliefs are typically only roughly encoded into probabilistic models. There are at least three good reasons for this. First, it is often hard to really pin down precisely what you believe: What tail behaviour should a variable have? Are there latent variables at play? <em>Et cetera.</em> Second, even when you do have an ideal model in mind, it can be mathematically challenging to accurately translate that into a probability distribution. Third, many modelling choices are often based on convenience such as mathematical tractability.</p>
<p>Should we be worried about this? Surely roughly encoding our prior knowledge is sufficient?</p>
<p>Unfortunately, seemingly small or irrelevant inaccuracies in the model can greatly affect the posterior and therefore downstream decision making. For example, <a href="https://doi.org/10.1214/aos/1176349830">Diaconis & Freedman (1986)</a> show that “in high-dimensional problems, arbitrary details of the prior really matter”. Similarly, <a href="https://doi.org/10.2307/2291091">Kass & Raftery (1995)</a> conclude that, in the context of Bayesian model comparison, “[t]he chief limits of Bayes factors are their sensitivity to the assumptions in the parametric model and the choice of priors.” Indeed, for this reason, the textbook presentation of model comparison, such as that in <a href="http://www.inference.org.uk/mackay/itila/">David MacKay’s excellent textbook</a>, should be seen only as a pedagogical depiction of Bayesian inference at work, rather than an approach that will bear practical fruit: discrete Bayesian model comparison does not work in practice.<sup id="fnref:7" role="doc-noteref"><a href="#fn:7" class="footnote" rel="footnote">7</a></sup> As a more recent example of the importance of priors in high-dimensional settings, experiments by <a href="https://arxiv.org/abs/2002.02405">Wenzel <em>et al.</em> (2020)</a> suggest that the usual choice of Gaussian prior<sup id="fnref:17" role="doc-noteref"><a href="#fn:17" class="footnote" rel="footnote">8</a></sup> for Bayesian neural network models contributes to the <em>cold posterior effect</em>, in which the Bayesian posterior is outperformed on prediction tasks by strongly tempered versions.</p>
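<p>The sensitivity of Bayes factors to the prior is easy to reproduce in a classic toy problem (our choice of illustration, related to Lindley’s paradox). We observe the mean $\bar{x}$ of $n$ draws from $\mathcal{N}(\mu, 1)$ and compare $H_0\colon \mu = 0$ against $H_1\colon \mu \sim \mathcal{N}(0, \tau^2)$. Holding the data fixed and changing only the prior width $\tau$ moves the Bayes factor from favouring $H_1$ to favouring $H_0$.</p>

```python
import numpy as np

def normal_pdf(x, var):
    """Density of N(0, var) at x."""
    return np.exp(-0.5 * x**2 / var) / np.sqrt(2 * np.pi * var)

def bayes_factor_01(xbar, n, tau, sigma=1.0):
    """Bayes factor for H0: mu = 0 against H1: mu ~ N(0, tau^2), given the mean xbar
    of n draws from N(mu, sigma^2). Under H0, xbar ~ N(0, sigma^2 / n); under H1,
    xbar ~ N(0, sigma^2 / n + tau^2)."""
    s2 = sigma**2 / n
    return normal_pdf(xbar, s2) / normal_pdf(xbar, s2 + tau**2)

# Same data (xbar = 0.3 from n = 100 draws), three prior widths under H1.
bfs = {tau: bayes_factor_01(xbar=0.3, n=100, tau=tau) for tau in (1.0, 10.0, 100.0)}
for tau, bf in bfs.items():
    print(f"tau={tau:>6}: BF01={bf:.3f}")   # below 1 favours H1, above 1 favours H0
```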
<p>Can theory about the performance of the Bayesian approach in the face of model misspecification act as a comfort blanket? There are theorems, analogous to Doob’s consistency theorem in the well-specified case, that describe situations when the Bayesian posterior will still concentrate on the true parameter value even when the model is misspecified (<a href="https://arxiv.org/abs/math/0607023">Kleijn & van der Vaart, 2006</a>; <a href="https://www.jstor.org/stable/24310519">De Blasi & Walker, 2013</a>; <a href="https://arxiv.org/abs/1312.4620">Ramamoorthi <em>et al.</em>, 2015</a>), but, as <a href="https://arxiv.org/abs/1412.3730">Grünwald & van Ommen (2017)</a> point out<sup id="fnref:8" role="doc-noteref"><a href="#fn:8" class="footnote" rel="footnote">9</a></sup>, “[these theorems hold] under regularity conditions that are substantially stronger than those needed for consistency when the model is correct.”<sup id="fnref:9" role="doc-noteref"><a href="#fn:9" class="footnote" rel="footnote">10</a></sup></p>
<p>Theoretical understanding of the consequences of model misspecification is not yet available to rescue us. The probabilistic approach to modelling therefore poses an unsettling dichotomy: on the one hand, you’re free to choose the prior, because it is <em>your belief</em>, <em>your prior</em>; but on the other hand, you must get your prior exactly right, because seemingly small or irrelevant changes can greatly affect the conclusions.</p>
<p><strong>A way forward?</strong> This weakness of the Bayesian approach used to keep us awake at night. However, taking a slightly different perspective cures our insomnia. The conventional view dogmatically insists that Bayesians should encapsulate their prior assumptions once and for all in the modelling step, before seeing any data. When the data arrive, they perform inference and then decision making, and at that point they’re done. The alternative perspective instead uses the three-stage process as a <em>consistency-checking device</em>: if I made these assumptions, then these would be the corresponding coherent statistical inferences. In this way, we are free to explore a range of different assumptions, assessing their consequences both for the model and for the inferences we make. We are free to modify the model accordingly until we arrive at something that approximates well what we believe. This view has been called the <em>hypothetico-deductive</em> view of Bayesian inference (<a href="https://arxiv.org/abs/1006.3868">Gelman & Shalizi, 2011</a>), and it has the advantage of matching the way many of us use these methods in practice. The disadvantage is that double-dipping (using the same data both to choose the model and to draw conclusions from it) is possible, so you must be careful about the checks and diagnostics that you perform on the model and corresponding inferences.</p>
<h2 id="weakness-3-approximate-inference">Weakness 3: Approximate inference</h2>
<p>The weakness that kept Zoubin Ghahramani (and many other Bayesians) awake at night is that the Bayesian posterior is computationally intractable in all but the simplest cases. Consequently, approximations are necessary, which means that all the theoretical guarantees and justifications that are true for the exact Bayesian posterior no longer hold. Worse still, we do not know in what ways approximation will affect different aspects of the solution.</p>
<p>For example, one common approximation technique is <em>variational inference</em> (<a href="http://dx.doi.org/10.1561/2200000001">Wainwright & Jordan, 2008</a>), which side-steps the intractable computations that arise from applying the sum and product rules by projecting distributions onto simpler, tractable alternatives. Variational inference is therefore <strong>guaranteed to be incoherent</strong>: its solutions differ from those obtained by exactly applying the sum and product rules. How do these approximate inferences differ from the true ones?</p>
<p>Variational inference is known to (1) underestimate posterior uncertainty if factorised approximations are used<sup id="fnref:10" role="doc-noteref"><a href="#fn:10" class="footnote" rel="footnote">11</a></sup> and (2) bias parameter learning towards overly simple models that underfit the data (<a href="http://www.gatsby.ucl.ac.uk/~turner/Publications/turner-and-sahani-2011a.html">Turner & Sahani, 2011</a>). An example of the first phenomenon is the observation that the mean-field variational approximation in neural networks is unable to model in-between uncertainty, that is, increased uncertainty in regions between clusters of observed data (<a href="https://arxiv.org/abs/1906.11537">Foong <em>et al.</em>, 2019</a>). Examples of the second phenomenon include over-pruning in mixture models (<a href="http://www.inference.org.uk/mackay/minima.pdf">MacKay, 2001</a>; <a href="https://doi.org/10.1214/06-BA104">Blei & Jordan, 2006</a>), variational autoencoders<sup id="fnref:11" role="doc-noteref"><a href="#fn:11" class="footnote" rel="footnote">12</a></sup> (<a href="https://arxiv.org/abs/1509.00519">Burda <em>et al.</em>, 2015</a>; <a href="https://arxiv.org/abs/1611.02731">Chen <em>et al.</em>, 2016</a>; <a href="https://arxiv.org/abs/1706.02262">Zhao <em>et al.</em>, 2017</a>; <a href="https://arxiv.org/abs/1706.03643">Yeung <em>et al.</em>, 2017</a>), and Bayesian neural networks (<a href="https://arxiv.org/abs/1801.06230">Trippe & Turner, 2018</a>).</p>
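<p>To make the first phenomenon concrete, here is a minimal sketch (our own illustration, not taken from the cited works) of coordinate-ascent mean-field variational inference on a correlated two-dimensional Gaussian target $\mathcal{N}(\boldsymbol\mu, \boldsymbol\Sigma)$ with unit variances and correlation $\rho$. For a Gaussian target, the optimal factor $q_i$ is Gaussian with variance $1/\Lambda_{ii}$, where $\boldsymbol\Lambda = \boldsymbol\Sigma^{-1}$ is the precision matrix, so each approximate variance is $1 - \rho^2$ while the true marginal variance is $1$. The target, the value $\rho = 0.9$, and the function name <code>cavi_gaussian</code> are hypothetical choices.</p>

```python
import numpy as np

def cavi_gaussian(mu, cov, m_init, n_sweeps=100):
    """Mean-field VI for a Gaussian target p(x) = N(mu, cov).

    The approximation is q(x) = prod_i N(x_i; m_i, v_i); each factor is
    updated in turn by coordinate ascent on the ELBO.
    """
    lam = np.linalg.inv(cov)           # precision matrix of the target
    m = np.array(m_init, dtype=float)  # factor means, updated iteratively
    v = 1.0 / np.diag(lam)             # optimal factor variances: 1 / Lambda_ii
    for _ in range(n_sweeps):
        for i in range(len(m)):
            others = np.arange(len(m)) != i
            # Update the mean of factor i given the current means of the others.
            m[i] = mu[i] - v[i] * (lam[i, others] @ (m[others] - mu[others]))
    return m, v

rho = 0.9
mu = np.zeros(2)
cov = np.array([[1.0, rho], [rho, 1.0]])

m, v = cavi_gaussian(mu, cov, m_init=[3.0, -2.0])
print("true marginal variances:", np.diag(cov))
print("mean-field variances:   ", v)
print("mean-field means:       ", m)
```

<p>The means are recovered, but for $\rho = 0.9$ the mean-field variances are $0.19$ rather than $1$: the factorised approximation is overconfident about each coordinate, and the stronger the correlation, the worse the underestimate.</p>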
<p>Inevitably, these errors are amplified when repeated approximation steps are required. For example, in <em>online</em> or <em>continual learning</em>, the goal is to incorporate new observations sequentially without forgetting old ones (<a href="https://arxiv.org/abs/1710.10628">Nguyen <em>et al.</em>, 2017</a>). In theory, repeated application of Bayes’ rule provides the optimal solution, but in practice an approximation such as variational inference is required at each step. The approximation errors then accumulate across steps, which eventually leads the model to “forget” old data. Consequently, approximate Bayesian methods have to be checked and benchmarked for catastrophic forgetting, just like their non-Bayesian counterparts.</p>
<p>We lack a general theory that justifies variational inference. Recent work has shown that it is frequentist consistent (<a href="https://arxiv.org/abs/1705.03439">Wang & Blei, 2017</a>) under assumptions similar to those required for consistency of Bayesian inference under model misspecification. Although useful, this is a long way short of a general compelling justification for the approach.<sup id="fnref:12" role="doc-noteref"><a href="#fn:12" class="footnote" rel="footnote">13</a></sup></p>
<p>Similar arguments can be made against other approaches to approximate inference, such as Monte Carlo methods. For example, <em>Markov chain Monte Carlo</em> (MCMC) is guaranteed, under mild conditions, to eventually give you a close-enough answer; but it may take an infeasibly long time to do so, and diagnosing whether you have reached that point is very difficult (we often rely on heuristics to check whether the chain has been run for long enough). So, although excellent software for performing MCMC exists,<sup id="fnref:13" role="doc-noteref"><a href="#fn:13" class="footnote" rel="footnote">14</a></sup> ensuring that it is performing correctly is delicate.<sup id="fnref:14" role="doc-noteref"><a href="#fn:14" class="footnote" rel="footnote">15</a></sup></p>
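<p>As an example of such a heuristic, here is a minimal sketch of the classic Gelman-Rubin potential scale reduction factor $\hat R$, which compares between-chain and within-chain variance across several independently initialised chains. The synthetic “chains” are plain Gaussian draws standing in for sampler output, and the function name <code>gelman_rubin_rhat</code> is hypothetical; modern tools use refined variants, such as split and rank-normalised $\hat R$, rather than this exact formula.</p>

```python
import numpy as np

def gelman_rubin_rhat(chains):
    """Classic potential scale reduction factor for one scalar quantity.

    chains: array of shape (n_chains, n_draws), ideally started from
    over-dispersed initial points.
    """
    chains = np.asarray(chains, dtype=float)
    n = chains.shape[1]
    B = n * chains.mean(axis=1).var(ddof=1)  # between-chain variance
    W = chains.var(axis=1, ddof=1).mean()    # average within-chain variance
    var_plus = (n - 1) / n * W + B / n       # pooled estimate of the posterior variance
    return np.sqrt(var_plus / W)

rng = np.random.default_rng(0)
# Stand-ins for sampler output: four chains that agree ...
mixed = rng.normal(size=(4, 1000))
# ... and four chains where two sit in a different region (offset by 3).
stuck = rng.normal(size=(4, 1000)) + np.array([[0.0], [0.0], [3.0], [3.0]])

print("R-hat, mixed chains:", gelman_rubin_rhat(mixed))
print("R-hat, stuck chains:", gelman_rubin_rhat(stuck))
```

<p>Agreeing chains give $\hat R \approx 1$, while the stuck chains give a value well above 1. The converse is the delicate part: $\hat R \approx 1$ is consistent with convergence but does not prove it; if every chain is stuck in the same region, this diagnostic cannot tell.</p>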
<p>Let us take a step back and consider approximate inference in the context of Bayesian decision theory. In its ideal form, Bayesian decision theory is a cleanly separated three-step procedure consisting of model specification, inference and loss minimisation. However, in practice the use of approximate inference entangles the three steps: simpler models admit more accurate, and therefore more coherent, inference; and downstream decision making dictates which aspects of the true posterior your approximate posterior should capture, so it feeds into the inference stage rather than being decoupled from it (<a href="http://proceedings.mlr.press/v15/lacoste_julien11a.html">Lacoste-Julien <em>et al.</em>, 2011</a>).<sup id="fnref:15" role="doc-noteref"><a href="#fn:15" class="footnote" rel="footnote">16</a></sup></p>
<p><strong>A way forward?</strong> To our minds, the jury is still out on the best way forward. Hopefully a more general theory will emerge, potentially based on decision making under uncertainty and computational constraints, that extends the ideas of Cox, Ramsey and de Finetti, or Savage. In the meantime, continuing to provide more limited theoretical characterisations of the properties of existing inference approaches is vital. Practitioners should also test the hell out of their inference schemes to gain confidence in them. Testing on special cases where ground-truth posteriors or underlying variables are known is helpful. The acid test is whether your inference scheme works on the real-world data you care about, so test cases also need to replicate aspects of this situation. Here, the ability of probabilistic models to generate fake data, potentially from models trained on real data, is very useful. This fits nicely with the hypothetico–deductive loop: specify your model, perform approximate inference, check your inferences, now check your model, review your modelling choices, and start the process again.</p>
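<p>A minimal sketch of such a check, under assumptions of our own choosing: fake data are simulated from a Beta-Bernoulli model with a known parameter, the exact posterior is available in closed form, and a basic random-walk Metropolis sampler stands in for “your inference scheme”. The model, the step size, the sample sizes, and the function names <code>log_post</code> and <code>metropolis</code> are all hypothetical.</p>

```python
import numpy as np

def log_post(theta, k, n, a=1.0, b=1.0):
    """Unnormalised log posterior: k successes in n Bernoulli trials, Beta(a, b) prior."""
    if not 0.0 < theta < 1.0:
        return -np.inf
    return (a - 1 + k) * np.log(theta) + (b - 1 + n - k) * np.log1p(-theta)

def metropolis(k, n, n_samples=20000, step=0.1, seed=0):
    """Random-walk Metropolis sampler: the approximate inference scheme under test."""
    rng = np.random.default_rng(seed)
    theta = 0.5
    lp = log_post(theta, k, n)
    samples = np.empty(n_samples)
    for t in range(n_samples):
        proposal = theta + step * rng.normal()
        lp_prop = log_post(proposal, k, n)
        if np.log(rng.uniform()) < lp_prop - lp:  # Metropolis acceptance test
            theta, lp = proposal, lp_prop
        samples[t] = theta
    return samples

# Simulate fake data from a known parameter value.
n, theta_true = 50, 0.3
k = np.random.default_rng(1).binomial(n, theta_true)

# Ground truth: with a Beta(1, 1) prior the posterior is Beta(1 + k, 1 + n - k).
alpha, beta = 1 + k, 1 + n - k
exact_mean = alpha / (alpha + beta)
exact_sd = np.sqrt(alpha * beta / ((alpha + beta) ** 2 * (alpha + beta + 1)))

samples = metropolis(k, n)[5000:]  # discard burn-in
mc_mean, mc_sd = samples.mean(), samples.std()
print(f"posterior mean: exact {exact_mean:.4f}, MCMC {mc_mean:.4f}")
print(f"posterior sd:   exact {exact_sd:.4f}, MCMC {mc_sd:.4f}")
```

<p>Passing such a test is necessary rather than sufficient: a scheme that recovers a one-dimensional conjugate posterior may still fail on the models and data you actually care about, which is why the test cases should also replicate aspects of that situation.</p>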
<p>Of course the amazing and diverse practical successes of Bayesian inference — from <a href="http://mathcenter.oxford.emory.edu/site/math117/bayesTheorem/enigma_and_bayes_theorem.pdf">cracking Enigma</a> and <a href="https://projecteuclid.org/journals/statistical-science/volume-29/issue-1/Search-for-the-Wreckage-of-Air-France-Flight-AF-447/10.1214/13-STS420.full">finding Air France Flight 447</a>, to <a href="https://en.wikipedia.org/wiki/TrueSkill">the TrueSkill match-making system</a> — give us great confidence in the utility of the probabilistic approach.</p>
<h2 id="conclusion">Conclusion</h2>
<p>We started the first of these two posts by remarking that it was striking that Michael Jordan, Geoff Hinton, and Zoubin Ghahramani — individuals who had made huge contributions to the probabilistic brand of machine learning — maintained reservations about that very approach. We hope that this post has brought colour to these concerns.</p>
<p>Michael Jordan thought that Bayesians and frequentists should reconcile and operate arm-in-arm. This makes sense since Bayesians develop inference procedures and frequentist methods can be used to evaluate them.</p>
<p>Geoff Hinton’s view was that Bayesian approaches don’t cut it for the problems he is interested in, like object recognition and machine translation. This is understandable since we have seen that Bayesian methods are useful when (1) you’re able to encode knowledge carefully into a detailed probabilistic model; (2) you are uncertain and want to propagate uncertainty accurately; and (3) you can perform accurate approximate inference. Do these conditions really hold in situations like object recognition or machine translation? These problems involve learning complex high-dimensional representations from stacks of largely noise-free images, words, or speech, and, to quote Zoubin Ghahramani, “if our goal is to learn a representation (…) then Bayesian inference gets in the way”.<sup id="fnref:16" role="doc-noteref"><a href="#fn:16" class="footnote" rel="footnote">17</a></sup></p>
<p>Zoubin Ghahramani admitted that approximate inference kept him awake at night. We have seen that the arguments that we make to justify the Bayesian approach — Cox’s and Savage’s theorems, Dutch books, <em>et cetera</em> — are practically meaningless because our inference is approximate. The elegant separation of modelling, inference, and decision making steps is, in fact, a tangled web.</p>
<p>So, when we teach the ideas of probabilistic modelling and inference and write boilerplate in our papers, we’d ask that you pause for breath before trotting out the standard justifications. Instead, consider evangelising a little less about how principled the approach is, mention the importance of exploring different modelling assumptions and their consequences, and stress that the quality of approximate inference should be evaluated in a systematic and principled way.</p>
<div class="footnotes" role="doc-endnotes">
<ol>
<li id="fn:1" role="doc-endnote">
<p>Also quoted by <a href="https://arxiv.org/abs/1507.06597">Terenin & Draper (2015)</a>. <a href="https://arxiv.org/abs/1105.5450">Halpern (1999)</a> constructs a counterexample to Cox’s theorem if the additional technical assumption, <a href="https://doi.org/10.1017/CBO9780511526596">Paris’ (1994)</a> density assumption, is omitted. <a href="#fnref:1" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:2" role="doc-endnote">
<p><a href="https://doi.org/10.1016/S0888-613X(03)00051-3">Van Horn (2003)</a> concludes that “[they] cannot make a compelling case for all of [Cox’s theorem’s] requirements, however, and there remains disagreement as to the desirability of several of them.” <a href="#fnref:2" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:3" role="doc-endnote">
<p>Interestingly, there are two-dimensional theories of probability; see Section 4.1.2 by <a href="https://scripties.uba.uva.nl/download?fid=639254">Bakker (2016)</a> for an overview. <a href="#fnref:3" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:4" role="doc-endnote">
<p><a href="https://doi.org/10.1086/289350">Skyrms (1987, p. 4)</a> states a dynamic Dutch book argument due to David Lewis (reported by Paul Teller; 1973, 1976). Skyrms also says (p. 19) that “(…) not every learning situation is of the kind to which conditionalization applies. The situation may not be of the kind that can be correctly described or even usefully approximated as the attainment of certainty by the agent in some proposition in his degree of belief space. The rule of belief change by probability kinematics on a partition was designed to apply to a much broader class of learning situations than the rule of conditionalization.” In the paper, Skyrms describes the Observation Game, a learning situation where conditionalisation does not apply, where a generalisation of conditionalisation called <em>probability kinematic</em> (<a href="https://press.uchicago.edu/ucp/books/book/chicago/L/bo3640589.html">Jeffrey, 1965</a>) — essentially, agreement with Bayes’ rule on only a given partition of the probability space — is necessary and, under certain conditions, sufficient to be <em>bulletproof</em> — a strong coherence condition that excludes a Dutch book. <a href="#fnref:4" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:5" role="doc-endnote">
<p>For example, see Example 5.7.2 by <a href="https://doi.org/10.1007/b98854">Lehmann & Casella</a> (<a href="https://doi.org/10.1007/b98854">1998</a>; also <a href="https://doi.org/10.1214/aos/1176343853">Makani, 1977</a>). <a href="#fnref:5" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:6" role="doc-endnote">
<p><a href="https://doi.org/10.2307/2281641">Savage’s (1962)</a> comment on personal probability that was reproduced at the start of this post captures this sentiment well. <a href="#fnref:6" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:7" role="doc-endnote">
<p>Andrew Gelman and David MacKay discuss this issue <a href="https://statmodeling.stat.columbia.edu/2011/12/04/david-mackay-and-occams-razor/">here</a> and <a href="https://statmodeling.stat.columbia.edu/2011/06/09/difficulties_wi/">here</a>. <a href="#fnref:7" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:17" role="doc-endnote">
<p>We previously wrote “the usual choice of <em>a</em> Gaussian prior”, which, unfortunately, is slightly ambiguous. We mean to refer to the usual choices of $\mathbf{w} \sim \mathcal{N}(0, 1)$ or $\mathbf{w} \sim \mathcal{N}(0, 1/\sqrt{n})$. <a href="#fnref:17" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:8" role="doc-endnote">
<p>See also <a href="https://stats.stackexchange.com/a/279798">this great answer</a> by Peter Grünwald on StackExchange. <a href="#fnref:8" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:9" role="doc-endnote">
<p>As a solution, <a href="https://arxiv.org/abs/1412.3730">Grünwald & van Ommen (2017)</a> propose to replace the posterior by a generalised posterior, which depends on a learning rate. This, however, requires an appropriate choice of the learning rate, obscures the meaning of the likelihood, and, more importantly, invalidates the fundamental justifications which hold for Bayes’ rule. <a href="#fnref:9" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:10" role="doc-endnote">
<p>Often the approximations involve some form of factorisation assumption, but alternatives exist that have different properties. For example, inducing point approximations for Gaussian processes (<a href="http://proceedings.mlr.press/v5/titsias09a.html">Titsias, 2009</a>) tend to overestimate uncertainty. <a href="#fnref:10" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:11" role="doc-endnote">
<p>In an attempt to improve the variational approximation, it has been proposed to recalibrate the variational objective by reweighting the KL term (<a href="https://arxiv.org/abs/1612.00410">Alemi <em>et al.</em>, 2017</a>; <a href="https://openreview.net/forum?id=Sy2fzU9gl">Higgins <em>et al.</em>, 2017</a>), which is closely related to the cold posterior effect. But by attempting to fix variational inference in this way, we are blurring the lines between Bayesian modelling and end-to-end optimisation, producing cleverly regularised estimators rather than models which reason according to fundamental principles. <a href="#fnref:11" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:12" role="doc-endnote">
<p>A pragmatic solution identifies standard practice, which uses point estimates for the unknowns such as $\hat X_{\text{MAP}}$, as effectively summarising the posterior $p(X \cond D)$ by a Dirac delta function, a probability distribution concentrated on $\hat X_{\text{MAP}}$. <a href="http://www.inference.org.uk/mackay/itila/">David MacKay (2003, Section 33.6)</a> says that, “[f]rom this perspective, <em>any</em> approximating distribution $Q(x;\theta)$, no matter how crummy it is, <em>has</em> to be an improvement on the spike produced by the standard method!” However, “has” is doing a lot of work in this sentence. <a href="#fnref:12" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:13" role="doc-endnote">
<p>For example, <a href="https://mc-stan.org/">Stan</a>, <a href="https://docs.pymc.io/">PyMC3</a>, or <a href="https://turing.ml/">Turing.jl</a>, but there is <a href="https://en.wikipedia.org/wiki/Probabilistic_programming#List_of_probabilistic_programming_languages">more out there</a>. <a href="#fnref:13" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:14" role="doc-endnote">
<p>Michael Betancourt has a <a href="https://betanalpha.github.io/assets/case_studies/markov_chain_monte_carlo.html#5_robust_application_of_markov_chain_monte_carlo_in_practice">great post</a> about responsible use of MCMC in practice. <a href="#fnref:14" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:15" role="doc-endnote">
<p>Indeed, the idea of decision-making-aware inference has been successful in the meta-learning setting (<a href="https://arxiv.org/abs/1805.09921">Gordon <em>et al.</em>, 2019</a>). It seems likely that the entanglement of inference and decision making (as a consequence of the intractability of the posterior) could justify the recent trend towards end-to-end systems, which directly optimise for the metric of interest, thereby circumventing these issues. <a href="#fnref:15" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
<li id="fn:16" role="doc-endnote">
<p>The development of “Bayesian-inspired” approaches, which blend end-to-end deep learning with ideas from probabilistic modelling and inference is arguably a reaction to these concerns. The goal of these developments is to provide excellent representation learning and strong predictive performance whilst still handling uncertainty, but without claiming strict adherence to the Bayesian dogma. <a href="#fnref:16" class="reversefootnote" role="doc-backlink">↩</a></p>
</li>
</ol>
</div>Wessel BruinsmaThe theory of subjective probability describes ideally consistent behaviour and ought not, therefore, be taken too literally. — Leonard Jimmie Savage (1917–1971)