Flow matching (FM) is a recent generative modelling paradigm which has rapidly been gaining popularity in the deep probabilistic ML community. Flow matching combines aspects from Continuous Normalising Flows (CNFs) and Diffusion Models (DMs), alleviating key issues both methods have. In this blogpost we’ll cover the main ideas and unique properties of FM models starting from the basics.
Let’s assume we have data samples $x_1, x_2, \ldots, x_n$ from a distribution of interest $q_1(x)$, whose density is unknown. We’re interested in using these samples to learn a probabilistic model approximating $q_1$. In particular, we want efficient generation of new samples (approximately) distributed according to $q_1$. This task is referred to as generative modelling.
The advancement in generative modelling methods over the past decade has been nothing short of revolutionary. In 2012, Restricted Boltzmann Machines, then the leading generative model, were just about able to generate MNIST digits. Today, state-of-the-art methods are capable of generating high-quality images, audio and language, as well as modelling complex biological and physical systems. Unsurprisingly, these methods are now venturing into video generation.
Flow Matching (FM) models are in nature most closely related to (Continuous) Normalising Flows (CNFs). Therefore, we start this blogpost by briefly recapping the core concepts behind CNFs. We then continue by discussing the difficulties of CNFs and how FM models address them.
Let $\phi: \mathbb{R}^d \rightarrow \mathbb{R}^d$ be a continuously differentiable function which transforms elements of $\mathbb{R}^d$, with a continuously differentiable inverse $\phi^{-1}: \mathbb{R}^d \to \mathbb{R}^d$. Let $q_0(x)$ be a density on $\mathbb{R}^d$ and let $p_1(\cdot)$ be the density induced by the following sampling procedure
\[\begin{equation*} \begin{split} x &\sim q_0 \\ y &= \phi(x), \end{split} \end{equation*}\]which corresponds to transforming the samples of $q_0$ by the mapping $\phi$. Using the change-of-variable rule we can compute the density of $p_1$ as
\[\begin{align} \label{eq:changevar} p_1(y) &= q_0(\phi^{-1}(y)) \abs{\det\left[\frac{\partial \phi^{-1}}{\partial y}(y)\right]} \\ \label{eq:changevar-alt} &= \frac{q_0(x)}{\abs{\det\left[\frac{\partial \phi}{\partial x}(x)\right]}} \quad \text{with } x = \phi^{-1}(y) \end{align}\]where the last equality can be seen from the fact that $\phi \circ \phi^{-1} = \Id$ and a simple application of the chain rule^{1}. The quantity $\frac{\partial \phi^{-1}}{\partial y}$ is the Jacobian of the inverse map. It is a matrix of size $d\times d$ with entries $J_{ij} = \frac{\partial \phi^{-1}_i}{\partial y_j}$. Depending on the task at hand (likelihood evaluation or sampling), the formulation in $\eqref{eq:changevar}$ or $\eqref{eq:changevar-alt}$ is preferred (Friedman, 1987; Chen & Gopinath, 2000).
Suppose $\phi$ is a linear function of the form
\[\phi(x) = ax+b\]with scalar coefficients $a,b\in\mathbb{R}$, and let $p$ be a Gaussian with mean $\mu$ and variance $\sigma^2$, i.e.
\[p = \mathcal{N}(\mu, \sigma^2).\]We know from linearity of Gaussians that the induced $q$ will also be a Gaussian distribution, but with mean $a\mu+b$ and variance $a^2 \sigma^2$, i.e.
\[q = \mathcal{N}(a \mu + b, a^2 \sigma^2).\]It is more interesting, though, to verify that we obtain the same solution by applying the change-of-variable formula. The inverse map is given by
\[\phi^{-1}(y) = \frac{y-b}{a}\]and its derivative w.r.t. $y$ is thus $1/a$, assuming scalar inputs and $a > 0$ (so that the absolute value can be dropped). We thus obtain
\[\begin{align*} q(y) &= p\bigg(\frac{y-b}{a}\bigg) \frac{1}{a} \\ &= \mathcal{N}\bigg(\frac{y-b}{a}; \mu, \sigma^2\bigg) \frac{1}{a}\\ &= \frac{1}{\sqrt{2\pi\sigma^2}}\exp \bigg(-\frac{(y/a -b/a-\mu)^2}{2\sigma^2} \bigg)\frac{1}{a}\\ &= \frac{1}{\sqrt{2\pi(a\sigma)^2}}\exp \bigg(-\frac{1}{a^2}\frac{(y-(a\mu+b))^2}{2\sigma^2} \bigg) \\ &= \mathcal{N}\big(y; a\mu+b,a^2\sigma^2\big). \end{align*}\]We have thus verified that the change-of-variables formula can be used to compute the density of a Gaussian variable transformed by a linear mapping.
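As a quick sanity check, we can also verify this numerically. The sketch below (with arbitrary illustrative values for $a$, $b$, $\mu$ and $\sigma^2$) compares the change-of-variables density against the closed-form Gaussian:

```python
import math

def gauss_pdf(x, mean, var):
    """Density of N(mean, var) at x."""
    return math.exp(-(x - mean) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

# Illustrative values (any a != 0 works; we take a > 0 as in the derivation).
a, b = 2.0, -1.0
mu, sigma2 = 0.5, 1.5

def q_via_change_of_variables(y):
    # q(y) = p(phi^{-1}(y)) * |d phi^{-1}/dy|  with  phi(x) = a*x + b.
    x = (y - b) / a
    return gauss_pdf(x, mu, sigma2) * (1.0 / abs(a))

def q_closed_form(y):
    # Linearity of Gaussians: q = N(a*mu + b, a^2 * sigma^2).
    return gauss_pdf(y, a * mu + b, a ** 2 * sigma2)

for y in [-3.0, 0.0, 1.7, 4.2]:
    assert abs(q_via_change_of_variables(y) - q_closed_form(y)) < 1e-12
```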
Often, to simplify notation, we will use the ‘push-forward’ operator $[\phi]_{\#}$ to denote the change in density of applying an invertible map $\phi$ to an input density. That is
\[q(y) = ([\phi]_{\#} p)(y) = p\big(\phi^{-1}(y)\big) \abs{\det\left[\frac{\partial \phi^{-1}}{\partial y}(y)\right]}.\]If we make the choice of $a = 1$ and $b = \mu$, then we get $\mathcal{N}(\mu, 1)$, as can be seen in the figure below.
Transforming a base distribution $q_0$ into another $p_1$ via a transformation $\phi$ is interesting, yet its direct application in generative modelling is limited. In generative modelling, the aim is to approximate a distribution using only the available samples. This task therefore requires a transformation $\phi$ that maps samples from a “simple” distribution, such as $\mathcal{N}(0,I)$, to (approximately) the data distribution. A straightforward linear transformation, as in the previous example, is inadequate due to the highly non-Gaussian nature of typical data distributions. This motivates using a neural network as a flexible transformation $\phi_\theta$. The key task then becomes optimising the neural net’s parameters $\theta$.
Let’s denote the induced parametric density by the flow $\phi_\theta$ as $p_1 \triangleq [\phi_\theta]_{\#}p_0$.
A natural optimisation objective for learning the parameters $\theta \in \Theta$ is to consider maximising the probability of the data under the model:
\[\begin{equation*} \textrm{argmax}_{\theta}\ \ \mathbb{E}_{x\sim \mathcal{D}} [\log p_1(x)]. \end{equation*}\]Parameterising $\phi_\theta$ as a deep neural network leads to several constraints:
Designing flows $\phi$ therefore requires trading-off expressivity (of the flow and thus of the probabilistic model) with the above mentioned considerations so that the flow can be trained efficiently.
In particular, computing the determinant of the Jacobian is in general very expensive (as it would require $d$ automatic differentiation passes through the flow), so we impose structure on $\phi$^{2}.
Full-rank residual (Behrmann et al., 2019; Chen et al., 2010)
Expressive flows relying on a residual connection have been proposed as an interesting middle-ground between expressivity and efficient determinant estimation. They take the form:
\[\begin{equation} \label{eq:full_rank_res} \phi_k(x) = x + \delta ~u_k(x), \end{equation}\]for which an unbiased estimate of the log-likelihood can be obtained^{3}. As opposed to auto-regressive flows (Huang et al., 2018, Larochelle and Murray, 2011, Papamakarios et al., 2017) and low-rank residual normalising flows (Van Den Berg et al., 2018), the update in \eqref{eq:full_rank_res} has a full-rank Jacobian, typically leading to more expressive transformations.
We can also compose such flows to get a new flow:
\[\begin{equation*} \phi = \phi_K \circ \ldots \circ \phi_2 \circ \phi_1. \end{equation*}\]This can be a useful way to construct more expressive flows! The model’s log-likelihood is then given by summing each flow’s contribution
\[\begin{equation*} \log q(y) = \log p(\phi^{-1}(y)) + \sum_{k=1}^K \log \abs{\det\left[\frac{\partial \phi_k^{-1}}{\partial x_{k+1}}(x_{k+1})\right]} \end{equation*}\]with $x_k = \phi_k^{-1} \circ \ldots \circ \phi_K^{-1} (y)$ and $x_{K+1} = y$.
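To make the summed log-determinant formula concrete, here is a minimal sketch with a stack of 1D affine flows (illustrative coefficients), checking the formula against the closed-form Gaussian obtained by composing the affine maps directly:

```python
import math

def gauss_logpdf(x, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

# A stack of 1-D affine flows phi_k(x) = a_k * x + b_k (illustrative values).
flows = [(2.0, 1.0), (0.5, -3.0), (4.0, 0.2)]

def log_q(y):
    """log-density of N(0, 1) pushed through phi_3 . phi_2 . phi_1, via the sum formula."""
    log_det_sum = 0.0
    x = y
    for a, b in reversed(flows):          # invert the last flow first
        x = (x - b) / a                   # phi_k^{-1}
        log_det_sum += -math.log(abs(a))  # log |det d phi_k^{-1}| = -log|a_k|
    return gauss_logpdf(x, 0.0, 1.0) + log_det_sum

# Closed form: composing affine maps is affine, so the pushforward is Gaussian.
a_tot, b_tot = 1.0, 0.0
for a, b in flows:
    a_tot, b_tot = a * a_tot, a * b_tot + b

for y in [-2.0, 0.3, 5.0]:
    assert abs(log_q(y) - gauss_logpdf(y, b_tot, a_tot ** 2)) < 1e-12
```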
As mentioned previously, residual flows are transformations of the form $\phi(x) = x + \delta \ u(x)$ for some $\delta > 0$ and Lipschitz residual connection $u$. We can re-arrange this to get
\[\begin{equation*} \frac{\phi(x) - x}{\delta} = u(x) \end{equation*}\]which is looking awfully similar to $u$ being a derivative. In fact, letting $\delta = 1/K$ and taking the limit $K \rightarrow \infty$ under certain conditions^{4}, a composition of residual flows $\phi_K \circ \cdots \circ \phi_2 \circ \phi_1$ is given by an ordinary differential equation (ODE):
\[\begin{equation*} \frac{\dd x_t}{\dd t} = \lim_{\delta \rightarrow 0} \frac{x_{t+\delta} - x_t}{\delta} = \lim_{\delta \rightarrow 0} \frac{\phi_t(x_t) - x_t}{\delta} = u_t(x_t) \end{equation*}\]where the flow of the ODE $\phi_t: [0,1]\times\mathbb{R}^d\rightarrow\mathbb{R}^d$ is defined such that
\[\begin{equation*} \frac{d\phi_t}{dt} = u_t(\phi_t(x_0)). \end{equation*}\]That is, $\phi_t$ maps initial condition $x_0$ to the ODE solution at time $t$:
\[\begin{equation*} x_t \triangleq \phi_t(x_0) = x_0 + \int_{0}^t u_s(x_s) \dd{s} . \end{equation*}\]Of course, this only defines the map $\phi_t(x)$; for this to be a useful normalising flow, we still need to compute the log-abs-determinant of the Jacobian!
As it turns out, the density induced by $\phi_t$ (or equivalently $u_t$) can be computed via the following equation^{5}
\[\begin{equation*} \frac{\partial}{\partial t} p_t(x) = - (\nabla \cdot (u_t p_t))(x). \end{equation*}\]This statement on the time-evolution of $p_t$ is generally known as the Transport Equation. We refer to $p_t$ as the probability path induced by $u_t$.
Computing the total derivative (as $x_t$ also depends on $t$) in log-space yields^{6}
\[\begin{equation*} \frac{\dd}{\dd t} \log p_t(x_t) = - (\nabla \cdot u_t)(x_t) \end{equation*}\]resulting in the log density
\[\begin{equation*} \log p_t(x) = \log p_0(x_0) - \int_0^t (\nabla \cdot u_s)(x_s) \dd{s}. \end{equation*}\]Parameterising the vector field with a neural network $u_\theta: \mathbb{R}_+ \times \mathbb{R}^d \rightarrow \mathbb{R}^d$ therefore induces a parametric log-density
\[\log p_\theta(x) \triangleq \log p_1(x) = \log p_0(x_0) - \int_0^1 (\nabla \cdot u_\theta)(x_t) \dd t.\]In practice, to compute $\log p_t$ one can either solve both the time evolution of $x_t$ and its log density $\log p_t$ jointly
\[\begin{equation*} \frac{\dd}{\dd t} \Biggl( \begin{aligned} x_t \ \quad \\ \log p_t(x_t) \end{aligned} \Biggr) = \Biggl( \begin{aligned} u_\theta(t, x_t) \quad \\ - \div u_\theta(t, x_t) \end{aligned} \Biggr), \end{equation*}\]or solve only for $x_t$ and then use quadrature methods to estimate $\log p_t(x_t)$.
Feeding this (joint) vector field to an adaptive step-size ODE solver allows us to control both the error in the sample $x_t$ and the error in the $\log p_t(x)$.
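As a minimal illustration of solving the joint ODE, the sketch below uses plain fixed-step Euler integration (rather than an adaptive solver) with an illustrative linear field $u(t, x) = a x$ in 1D, whose divergence is simply the constant $a$, so that both $x_1$ and $\log p_1$ have closed forms to compare against:

```python
import math

# Euler-integrate the joint ODE for (x_t, log p_t(x_t)) with u(t, x) = a * x in 1-D.
# The divergence is the constant a, and the base density is p_0 = N(0, 1).
a = 0.7
x0 = 1.3
n_steps = 20_000
h = 1.0 / n_steps

def gauss_logpdf(x, mean, var):
    return -0.5 * math.log(2 * math.pi * var) - (x - mean) ** 2 / (2 * var)

x = x0
logp = gauss_logpdf(x0, 0.0, 1.0)
for _ in range(n_steps):
    x += h * (a * x)      # dx/dt = u(t, x)
    logp += h * (-a)      # d log p_t(x_t)/dt = -(div u)(x_t) = -a

# Analytic check: x_t = x0 * e^{a t}, so the pushforward p_1 is N(0, e^{2a}).
assert abs(x - x0 * math.exp(a)) < 1e-3
assert abs(logp - gauss_logpdf(x, 0.0, math.exp(2 * a))) < 1e-3
```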
One may legitimately wonder why we should bother with such time-continuous flows versus discrete residual flows. There are a couple of benefits:
Now that you know why CNFs are cool, let’s have a look at what such a flow would be for a simple example.
Let’s come back to our earlier example of mapping a 1D Gaussian to another one with a different mean. In contrast to before, where we derived a ‘one-shot’ (i.e. discrete) flow bridging the two Gaussians, we now aim to derive a time-continuous flow $\phi_t$ which corresponds to integrating a vector field $u_t$ over time.
We have the following two distributions
\[\begin{equation*} p_0 = \mathcal{N}(0, 1) \quad \text{and} \quad p_1 = \mathcal{N}(\mu, 1). \end{equation*}\]It’s not difficult to see that we can continuously bridge between these with a simple linear transformation
\[\begin{equation*} \phi(t, x_0) = x_0 + \mu t \end{equation*}\]which is visualized in the figure below.
By linearity, we know that every marginal $p_t$ is a Gaussian, and so
\[\begin{equation*} \mathbb{E}_{p_0}[\phi_t(x_0)] = \mu t \end{equation*}\]which, in particular, implies that $\mathbb{E}_{p_0}[\phi_1(x_0)] = \mu = \mathbb{E}_{p_1}[x_1]$. Similarly, we have
\[\begin{equation*} \mathrm{Var}_{p_0}[\phi_t(x_0)] = 1 \quad \implies \quad \mathrm{Var}_{p_0}[\phi_1(x_0)] = 1 = \mathrm{Var}_{p_1}[x_1] \end{equation*}\]Hence we have a probability path $p_t = \mathcal{N}(\mu t, 1)$ bridging $p_0$ and $p_1$.
Now let’s determine what the vector field $u_t(x)$ would be in this case. As mentioned earlier, $u(t, x)$ should satisfy the following
\[\begin{equation*} \dv{\phi_t}{t}(x_0) = u_t \big( \phi_t(x_0) \big). \end{equation*}\]Since we have already specified $\phi$, we can plug it in on the left hand side to get
\[\begin{equation*} \dv{\phi_t}{t}(x_0) = \dv{t} \big( x_0 + \mu t \big) = \mu \end{equation*}\]which gives us
\[\begin{equation*} \mu = u_t \big( x_0 + \mu t \big). \end{equation*}\]The above needs to hold for all $t \in [0, 1]$, and so it’s not too difficult to see that one such solution is the constant vector field
\[\begin{equation*} u_t(x) = \mu. \end{equation*}\]We could of course have gone the other way, i.e. define the $u_t$ such that $p_0 \overset{u_t}{\longleftrightarrow} p_1$ and derive the corresponding $\phi_t$ by solving the ODE.
As with any flow, CNFs can be trained by maximum likelihood
\[\mathcal{L}(\theta) = \mathbb{E}_{x\sim q_1} [\log p_1(x)],\]where the expectation is taken over the data distribution and $p_1$ is the parametric distribution. This involves integrating the time-evolution of samples $x_t$ and of the log-likelihood $\log p_t$, both terms being a function of the parametric vector field $u_{\theta}(t, x)$. This requires
CNFs are very expressive as they parametrise a large class of flows, and therefore of probability distributions. Yet training can be extremely slow due to the ODE integration at each iteration. One may wonder whether a ‘simulation-free’, i.e. not requiring any integration, training procedure exists for training these CNFs.
And that is exactly where Flow Matching (FM) comes in!
Flow matching is a simulation-free way to train CNF models where we directly formulate a regression objective w.r.t. the parametric vector field $u_\theta$ of the form
\[\begin{equation*} \mathcal{L}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1]} \mathbb{E}_{x \sim p_t}\left[\| u_\theta(t, x) - u(t, x) \|^2 \right]. \end{equation*}\]In the equation above, $u(t, x)$ would be a vector field inducing a probability path (or bridge) $p_t$ interpolating from the reference $p_0$ to $p_1$, i.e.
\[\begin{equation*} \log p_1(x) = \log p_0(x_0) - \int_0^1 (\nabla \cdot u_t)(x_t) \dd{t}. \end{equation*}\]In words: we’re just performing regression on $u_t(x)$ for all $t \in [0, 1]$.
Of course, this requires knowledge of a valid $u(t, x)$, and if we already have access to $u_t$, there’s no point in learning an approximation $u_{\theta}(t, x)$ in the first place! But as we will see in the next section, we can leverage this formulation to construct a useful target for $u_{\theta}(t, x)$ without having to explicitly compute $u(t, x)$.
This is where Conditional Flow Matching (CFM) comes to the rescue.
We say a valid $u_t$ because there is no unique vector field $u_t$; there are indeed many valid choices for $u_t$ inducing maps $p_0 \overset{\phi}{\longleftrightarrow} p_1$ as illustrated in the figure below. As we will see in what follows, in practice we have to pick a particular target $u_t$, which has practical implications.
First, let’s remind ourselves that the transport equation relates a vector field $u_t$ to (the time evolution of) a probability path $p_t$
\[\begin{equation*} \pdv{p_t(x)}{t} = - \nabla \cdot \big( u_t(x) p_t(x) \big), \end{equation*}\]so specifying $u_t$ is equivalent to specifying (the time evolution of) $p_t$. One key idea (Lipman et al., 2023 and Albergo & Vanden-Eijnden, 2022) is to express the probability path as a marginal over a joint involving a latent variable $z$: $p_t(x_t) = \int p(z) ~p_{t\mid z}(x_t\mid z) \textrm{d}z$. The term $p_{t\mid z}(x_t\mid z)$ is a conditional probability path, which must satisfy boundary conditions at $t=0$ and $t=1$ so that $p_t$ is a valid path interpolating between $q_0$ and $q_1$. In addition, as opposed to the marginal $p_t$, the conditional $p_{t\mid z}$ can be available in closed form.
In particular, as we have access to data samples $x_1 \sim q_1$, it sounds pretty reasonable to condition on $z=x_1$, leading to the following marginal probability path
\[\begin{equation*} p_t(x_t) = \int q_1(x_1) ~p_{t\mid 1}(x_t\mid x_1) \dd{x_1}. \end{equation*}\]In this setting, the conditional probability path $p_{t\mid 1}$ needs to satisfy the boundary conditions
\[\begin{equation*} p_0(x \mid x_1) = p_0(x) \quad \text{and} \quad p_1(x \mid x_1) = \mathcal{N}(x; x_1, \sigmamin^2 I) \xrightarrow[\sigmamin \rightarrow 0]{} \delta_{x_1}(x) \end{equation*}\]with $\sigmamin > 0$ small, and for whatever reference $p_0$ we choose, typically something “simple” like $p_0(x) = \mathcal{N}(x; 0, I)$, as illustrated in the figure below.
The conditional probability path also satisfies the transport equation with the conditional vector field $u_t(x \mid x_1)$:
\[\begin{equation} \label{eq:continuity-cond} \pdv{p_t(x \mid x_1)}{t} = - \nabla \cdot \big( u_t(x \mid x_1) p_t(x \mid x_1) \big). \end{equation}\]Lipman et al. (2023) introduced the notion of Conditional Flow Matching (CFM) by noticing that this conditional vector field $u_t(x \mid x_1)$ can express the marginal vector field $u_t(x)$ of interest via the conditional probability path $p_{t\mid 1}(x_t\mid x_1)$ as
\[\begin{equation} \label{eq:cf-from-cond-vf} \begin{split} u_t(x) &= \mathbb{E}_{x_1 \sim p_{1 \mid t}} \left[ u_t(x \mid x_1) \right] \\ &= \int u_t(x \mid x_1) \frac{p_t(x \mid x_1) q_1(x_1)}{p_t(x)} \dd{x}_1. \end{split} \end{equation}\]To see why this $u_t$ is the same vector field as the one defined earlier, i.e. the one generating the (marginal) probability path $p_t$, we need to show that the expression above for the marginal vector field $u_t(x)$ satisfies the transport equation
\[\begin{equation*} \pdv{\hlthree{p_t(x)}}{t} = - \nabla \cdot \big( \hltwo{u_t(x)} \hlthree{p_t(x)} \big). \end{equation*}\]Writing out the left-hand side, we have
\[\begin{equation*} \begin{split} \pdv{\hlthree{p_t(x)}}{t} &= \pdv{t} \int p_t(x \mid x_1) q(x_1) \dd{x_1} \\ &= \int \hlone{\pdv{t} \big( p_t(x \mid x_1) \big)} q(x_1) \dd{x_1} \\ &= - \int \hlone{\nabla \cdot \big( u_t(x \mid x_1) p_t(x \mid x_1) \big)} q(x_1) \dd{x_1} \\ &= - \int \hlfour{\nabla} \cdot \big( u_t(x \mid x_1) p_t(x \mid x_1) q(x_1) \big) \dd{x_1} \\ &= - \hlfour{\nabla} \cdot \int u_t(x \mid x_1) p_t(x \mid x_1) q(x_1) \dd{x_1} \\ &= - \nabla \cdot \bigg( \int u_t(x \mid x_1) \frac{p_t(x \mid x_1) q(x_1)}{\hlthree{p_t(x)}} {\hlthree{p_t(x)}} \dd{x_1} \bigg) \\ &= - \nabla \cdot \bigg( {\hltwo{\int u_t(x \mid x_1) \frac{p_t(x \mid x_1) q(x_1)}{p_t(x)} \dd{x_1}}} \ {\hlthree{p_t(x)}} \bigg) \\ &= - \nabla \cdot \big( \hltwo{u_t(x)} {\hlthree{p_t(x)}} \big) \end{split} \end{equation*}\]where in the $\hlone{\text{first highlighted step}}$ we used \eqref{eq:continuity-cond} and in the $\hltwo{\text{last highlighted step}}$ we used the expression of $u_t(x)$ in \eqref{eq:cf-from-cond-vf}.
The relation between $\phi_t(x_0)$, $\phi_t(x_0 \mid x_1)$ and their induced densities is illustrated in Figure 9 below. And since $\phi_t(x_0)$ and $\phi_t(x_0 \mid x_1)$ are solutions corresponding to the vector fields $u_t(x)$ and $u_t(x \mid x_1)$ with $x(0) = x_0$, Figure 9 is equivalent to Figure 10, but note the difference in the expectation taken to go from $u_t(x_0 \mid x_1) \longrightarrow u_t(x_0)$ compared to $\phi_t(x_0 \mid x_1) \longrightarrow \phi_t(x_0)$.
Let’s try to gain some intuition behind \eqref{eq:cf-from-cond-vf} and the relation between $u_t(x)$ and $u_t(x \mid x_1)$. We do so by looking at the following scenario
\[\begin{equation} \tag{G-to-G} \label{eq:g2g} \begin{split} p_0 = \mathcal{N}([-\mu, 0], I) \quad & \text{and} \quad p_1 = \mathcal{N}([+\mu, 0], I) \\ \text{with} \quad \phi_t(x_0 \mid x_1) &= (1 - t) x_0 + t x_1 \end{split} \end{equation}\]with $\mu = 10$ unless otherwise specified. We’re effectively transforming a Gaussian to another Gaussian using a simple time-linear map, as illustrated in the following figure.
In the end, we’re really just interested in learning the marginal paths $\phi_t(x_0)$ for initial points $x_0$ that are probable under $p_0$, which we can then use to generate samples $x_1 = \phi_1(x_0)$. In this simple example, we can obtain closed-form expressions for $\phi_t(x_0)$ corresponding to the conditional paths $\phi_t(x_0 \mid x_1)$ of \eqref{eq:g2g}, as visualised below.
With that in mind, let’s pick a random initial point $x_0$ from $p_0$, and then compare an MC estimator for $u_t(x_0)$ at different values of $t$ along the path $\phi_t(x_0)$, i.e. we’ll be looking at
\[\begin{equation*} \begin{split} u_t \big( \phi_t(x_0) \big) &= \E_{p_{1 \mid t}}\left[u_t \big( \phi_t(x_0) \mid x_1 \big)\right] \\ &\approx \frac{1}{n} \sum_{i = 1}^n u_t \big( \phi_t(x_0) \mid x_1^{(i)} \big) \ \text{with } x_1^{(i)} \sim p_{1 \mid t}(x_1 \mid \phi_t(x_0)). \end{split} \end{equation*}\]In practice we don’t have access to the posterior \(p_{1 \mid t}(x_1 \mid x_t)\), but in this specific setting we do have closed-form expressions for everything (Albergo & Vanden-Eijnden, 2022), and so we can visualise the marginal vector field \(u_t\big( \phi_t(x_0)\big)\) and the conditional vector fields \(u_t \big( \phi_t(x_0) \mid x_1^{(i)} \big)\) for all our “data” samples \(x_1^{(i)}\) and see how they compare. This is shown in the figure below.
From the above figures, we can immediately see how for small $t$, i.e. near 0, the posterior $p_{1 \mid t}(x_1 \mid x_t)$ is quite scattered, so the marginalisation giving $u_t$ involves many equally likely data samples $x_1$. In contrast, as $t$ increases and gets closer to 1, $p_{1 \mid t}(x_1 \mid x_t)$ becomes concentrated over far fewer samples $x_1$.
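We can mimic this MC estimator in a simplified 1D setting (not the 2D example from the figures): take $q_1 = \mathcal{N}(\mu, 1)$ and the Gaussian conditional path $p_t(x \mid x_1) = \mathcal{N}(t x_1, (1-t)^2)$, for which both the conditional and the marginal vector field are available in closed form. Sampling $x_1 \sim q_1$ and weighting by $p_t(x \mid x_1)$ gives a self-normalised importance-sampling estimate of the posterior expectation (all values below are illustrative):

```python
import math, random

random.seed(0)

# Illustrative 1-D setup: data q1 = N(mu, 1), conditional path
# p_t(x | x1) = N(t*x1, (1-t)^2), hence u_t(x | x1) = (x1 - x)/(1 - t).
mu, t, x = 2.0, 0.5, 1.0

def cond_density(x, x1):
    var = (1 - t) ** 2
    return math.exp(-(x - t * x1) ** 2 / (2 * var)) / math.sqrt(2 * math.pi * var)

def cond_field(x, x1):
    return (x1 - x) / (1 - t)

# Self-normalised importance sampling: draw x1 ~ q1 and weight by p_t(x | x1),
# which is proportional to the posterior p_{1|t}(x1 | x).
n = 100_000
samples = [random.gauss(mu, 1.0) for _ in range(n)]
weights = [cond_density(x, x1) for x1 in samples]
total = sum(weights)
u_mc = sum(w * cond_field(x, x1) for w, x1 in zip(weights, samples)) / total

# Closed form for this Gaussian case: the marginal path is N(t*mu, (1-t)^2 + t^2),
# and the marginal field is u_t(x) = ((2t-1)*x + mu*(1-t)) / ((1-t)^2 + t^2).
s2 = (1 - t) ** 2 + t ** 2
u_exact = ((2 * t - 1) * x + mu * (1 - t)) / s2
assert abs(u_mc - u_exact) < 0.05
```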
Moreover, equipped with the knowledge of \eqref{eq:cf-from-cond-vf}, we can replace
\[\begin{align} \mathcal{L}_{\mathrm{FM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1], x \sim p_t}\left[\| u_\theta(t, x) - u(t, x) \|^2 \right], \end{align}\]where $u_t(x) = \mathbb{E}_{x_1 \sim p_{1 \mid t}} \left[ u_t(x \mid x_1) \right]$, with an equivalent loss regressing the conditional vector field $u_t(x \mid x_1)$ and marginalising $x_1$ instead:
\[\begin{equation*} \mathcal{L}_{\mathrm{CFM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1], x_1 \sim q, x \sim p_t(\cdot \mid x_1)}\left[\| u_\theta(t, x) - u_t(x \mid x_1) \|^2 \right]. \end{equation*}\]These losses are equivalent in the sense that
\[\begin{equation*} \nabla_\theta \mathcal{L}_{\mathrm{FM}}(\theta) = \nabla_\theta \mathcal{L}_{\mathrm{CFM}}(\theta), \end{equation*}\]which implies that we can use \({\mathcal{L}}_{\text{CFM}}\) instead to train the parametric vector field $u_{\theta}$. We defer the full proof to the footnote^{9}, but show the key idea below. By developing the squared norm in both losses, we can easily show that the squared terms are equal or independent of $\theta$. Let’s develop the inner product term for \({\mathcal{L}}_{\text{FM}}\) and show that it is equal to the inner product of \({\mathcal{L}}_{\text{CFM}}\):
\[\begin{align} \mathbb{E}_{x \sim p_t} ~\langle u_\theta(t, x), \hltwo{u_t(x)} \rangle &= \int \langle u_\theta(t, x), \hltwo{\int} u_t(x \mid x_1) \hltwo{\frac{p_t(x \mid x_1)q(x_1)}{p_t(x)} dx_1} \rangle p_t(x) \mathrm{d} x \\ &= \int \langle u_\theta(t, x), \int u_t(x \mid x_1) p_t(x \mid x_1)q(x_1) dx_1 \rangle \dd{x} \\ &= \int \int \langle u_\theta(t, x), u_t(x \mid x_1) \rangle p_t(x \mid x_1)q(x_1) dx_1 \dd{x} \\ &= \mathbb{E}_{q_1(x_1) p(x \mid x_1)} ~\langle u_\theta(t, x), u_t(x \mid x_1) \rangle \end{align}\]where in the $\hltwo{\text{first highlighted step}}$ we used the expression of $u_t(x)$ in \eqref{eq:cf-from-cond-vf}.
The benefit of the CFM loss is that once we define the conditional probability path $p_t(x \mid x_1)$, we can construct an unbiased Monte Carlo estimator of the objective using samples $\big( x_1^{(i)} \big)_{i = 1}^n$ from the data target $q_1$!
This estimator can be efficiently computed since it involves an expectation over the joint $q_1(x_1)p_t(x \mid x_1)$ of the conditional vector field $u_t (x \mid x_1)$, both of which are available, as opposed to the marginal vector field $u_t$, which involves an expectation over the intractable posterior $p_{1 \mid t}(x_1 \mid x)$.
We note that, as opposed to the log-likelihood maximisation loss of CNFs, which does not express any preference over which vector field $u_t$ is learned, the CFM loss does specify one via the choice of conditional vector field, which the neural vector field $u_\theta$ regresses.
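Here is a minimal Monte Carlo sketch of the CFM loss in a 1D setting where everything is closed-form ($q_1 = \mathcal{N}(\mu, 1)$, $p_t(x \mid x_1) = \mathcal{N}(t x_1, (1-t)^2)$; all values illustrative). It checks that the marginal field attains a lower loss than a perturbed field, yet a strictly positive one, reflecting the fact that $\mathcal{L}_{\mathrm{CFM}}$ differs from $\mathcal{L}_{\mathrm{FM}}$ by a $\theta$-independent variance term:

```python
import math, random

random.seed(1)

# 1-D Gaussian setting: q1 = N(mu, 1), p_t(x | x1) = N(t*x1, (1-t)^2),
# so the regression target is u_t(x | x1) = (x1 - x)/(1 - t).
mu = 2.0

def marginal_field(t, x):
    # Closed-form marginal vector field for this Gaussian case.
    s2 = (1 - t) ** 2 + t ** 2
    return ((2 * t - 1) * x + mu * (1 - t)) / s2

def cfm_loss(u_theta, n=50_000):
    total = 0.0
    for _ in range(n):
        t = random.uniform(0.0, 0.999)    # avoid dividing by zero at t = 1
        x1 = random.gauss(mu, 1.0)        # x1 ~ q1
        xt = random.gauss(t * x1, 1 - t)  # xt ~ p_t(. | x1)
        target = (x1 - xt) / (1 - t)      # u_t(xt | x1)
        total += (u_theta(t, xt) - target) ** 2
    return total / n

loss_at_minimiser = cfm_loss(marginal_field)
loss_perturbed = cfm_loss(lambda t, x: marginal_field(t, x) + 1.0)

# The marginal field minimises the CFM loss, but the minimum is strictly
# positive: the residual is the posterior variance of the conditional field.
assert 0.0 < loss_at_minimiser < loss_perturbed
```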
Let’s now look at a practical example of a conditional vector field and the corresponding probability path. Suppose we want a conditional vector field which generates a path of Gaussians, i.e.
\[\begin{equation*} p_t(x \mid x_1) = \mathcal{N}(x; \mu_t(x_1), \sigma_t(x_1)^2 \mathrm{I}) \end{equation*}\]for some mean $\mu_t(x_1)$ and standard deviation $\sigma_t(x_1)$.
One conditional vector field inducing the above-defined conditional probability path is given by the following expression:
\[\begin{equation} \label{eq:gaussian-path} u_t(x \mid x_1) = \frac{\dot{\sigma_t}(x_1)}{\sigma_t(x_1)} (x - \mu_t(x_1)) + \dot{\mu_t}(x_1) \end{equation}\]as shown in the proof below.
A simple choice for the mean $\mu_t(x_1)$ and std. $\sigma_t(x_1)$ is the linear interpolation for both, i.e.
\[\begin{align*} \hlone{\mu_t(x_1)} &\triangleq t x_1 \quad &\text{and} \quad \hlthree{\sigma_t(x_1)} &\triangleq (1 - t) + t \sigmamin \\ \hltwo{\dot{\mu}_t(x_1)} &\triangleq x_1 \quad &\text{and} \quad \hlfour{\dot{\sigma}_t(x_1)} &\triangleq -1 + \sigmamin \end{align*}\]so that
\[\begin{equation*} \big( {\hlone{\mu_0(x_1)}} + {\hlthree{\sigma_0(x_1)}} x_0 \big) \sim p_0 \quad \text{and} \quad \big( {\hlone{\mu_1(x_1)}} + {\hlthree{\sigma_1(x_1)}} x_0 \big) \sim \mathcal{N}(x_1, \sigmamin^2 I) \end{equation*}\]for $x_0 \sim p_0 = \mathcal{N}(0, I)$. In addition, since for $\sigmamin \rightarrow 0$ the conditional map $\mu_t(x_1) + \sigma_t(x_1) x_0$ reduces to the linear interpolation $(1 - t) x_0 + t x_1$, letting $p_0 = \mathcal{N}([-\mu, 0], I)$ and $p_1 = \mathcal{N}([+\mu, 0], I)$ for some $\mu > 0$ brings us back to the \ref{eq:g2g} example from earlier.
We can then plug this choice of $\mu_t(x_1)$ and $\sigma_t(x_1)$ into \eqref{eq:gaussian-path} to obtain the conditional vector field, writing $\hlthree{\sigma_t(x_1)} = 1 - (1 - \sigmamin) t$ to make our lives simpler,
\[\begin{equation*} \begin{split} u_t(x \mid x_1) &= \frac{\hlfour{- (1 - \sigmamin)}}{\hlthree{1 - (1 - \sigmamin) t}} (x - \hlone{t x_1}) + \hltwo{x_1} \\ &= \frac{1}{(1 - t) + t \sigmamin} \bigg( - (1 - \sigmamin) (x - t x_1) + \big(1 - (1 - \sigmamin) t \big) x_1 \bigg) \\ &= \frac{1}{(1 - t) + t \sigmamin} \bigg( - (1 - \sigmamin) x + x_1 \bigg) \\ &= \frac{x_1 - (1 - \sigmamin) x}{1 - (1 - \sigmamin) t}. \end{split} \end{equation*}\]Below you can see the difference between $\phi_t(x_0)$ (top figure) and $\phi_t(x_0 \mid x_1)$ (bottom figure) for pairs $(x_0, x_1)$ with $x_0 \sim p_0$ and $x_1 = \phi_1(x_0)$. The paths are coloured by the sign of the 2nd vector component of $x_0$ to more clearly highlight the difference between the marginal and conditional flows.
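The simplification above is easy to check numerically; the sketch below evaluates both the general Gaussian-path formula and the final simplified expression at random (illustrative) points and confirms they agree:

```python
import math, random

random.seed(2)

# Check the algebra numerically: the general Gaussian-path field
# (sigma_dot/sigma)*(x - mu_t) + mu_dot should match
# (x1 - (1 - sigma_min)*x) / (1 - (1 - sigma_min)*t).
sigma_min = 1e-2

def u_general(t, x, x1):
    mu_t, mu_dot = t * x1, x1
    sigma_t, sigma_dot = 1 - (1 - sigma_min) * t, -(1 - sigma_min)
    return (sigma_dot / sigma_t) * (x - mu_t) + mu_dot

def u_simplified(t, x, x1):
    return (x1 - (1 - sigma_min) * x) / (1 - (1 - sigma_min) * t)

for _ in range(1000):
    t = random.uniform(0.0, 1.0)
    x = random.gauss(0.0, 3.0)
    x1 = random.gauss(0.0, 3.0)
    assert abs(u_general(t, x, x1) - u_simplified(t, x, x1)) < 1e-9
```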
Unfortunately not, no. There are two issues arising from crossing conditional paths. We will explain these shortly, but first we stress that this leads to
To get a better understanding of the two points above, let’s revisit the \ref{eq:g2g} example once more. As we see in the figures below, realizations of the conditional vector field $u_t(x \mid x_1)$, i.e. sampling from the process
\[\begin{equation*} \begin{split} x_0 & \sim p_0, \quad x_1 \sim q_1 \\ x_t & \triangleq \phi_t(x_0 \mid x_1) \end{split} \end{equation*}\]result in paths that are quite different from the marginal paths, as illustrated in the figures below.
In particular, we can see that the marginal paths $\phi_t(x)$ do not cross; this is indeed just the uniqueness property of ODE solutions. A realization of the conditional vector field $u_t(x \mid x_1)$ also exhibits the “non-crossing paths” property, similar to the marginal flows $\phi_t(x)$, however paths $\phi_t(x \mid x_1)$ corresponding to different realizations $x_1 \sim q_1$ may intersect, as highlighted in the figure above.
Consider two highlighted paths in the visualization of $u_t(x \mid x_1)$, with data samples $\hlone{x_1^{(1)}}$ and $\hlthree{x_1^{(2)}}$. When learning a parameterized vector field $u_{\theta}(t, x)$ via stochastic gradient descent (SGD), we approximate the CFM loss as:
\[\mathcal{L}_{\mathrm{CFM}}(\theta) \approx \frac{1}{2} \norm{u_{\theta}(t, \hlone{x_t^{(1)}}) - u(t, \hlone{x_t^{(1)}} \mid \hlone{x_1^{(1)}})}^2 + \frac{1}{2} \norm{u_{\theta}(t, \hlthree{x_t^{(2)}}) - u(t, \hlthree{x_t^{(2)}} \mid \hlthree{x_1^{(2)}})}^2\]where $t \sim \mathcal{U}[0, 1]$, $\hlone{x_1^{(1)}}, \hlthree{x_1^{(2)}} \sim q_1$, and $\hlone{x_t^{(1)}} \sim p_t(\cdot \mid \hlone{x_1^{(1)}}), \hlthree{x_t^{(2)}} \sim p_t(\cdot \mid \hlthree{x_1^{(2)}})$. We compute the gradient with respect to $\theta$ for a gradient step.
In such a scenario, we’re attempting to align $u_{\theta}(t, x)$ with two different vector fields whose corresponding paths are impossible under the marginal vector field $u(t, x)$ that we’re trying to learn! This fact can lead to increased variance in the gradient estimate, and thus slower convergence.
In slightly more complex scenarios, the situation becomes even more striking. Below we see a nice example from Liu et al. (2022) where our reference and target are two different mixtures of Gaussians in 2D, differing only by the sign of the mean in the x-component. Specifically,
\[\begin{equation} \tag{MoG-to-MoG} \label{eq:mog2mog} \begin{split} p_{\hlone{0}} &= (1 / 2)\mathcal{N}([{\hlone{-\mu}}, -\mu], I) + (1 / 2) \mathcal{N}([{\hlone{-\mu}}, +\mu], I) \\ \text{and} \quad p_{\hltwo{1}} &= (1 / 2) \mathcal{N}([{\hltwo{+\mu}}, -\mu], I) + (1 / 2) \mathcal{N}([{\hltwo{+\mu}}, +\mu], I) \\ \text{with} \quad \phi_t(x_0 \mid x_1) &= (1 - t) x_0 + t x_1 \end{split} \end{equation}\]where we set $\mu = 10$, unless otherwise specified.
Here we see that marginal paths (bottom figure) end up looking very different from the conditional paths (top figure). Indeed, at training time paths may intersect, whilst at sampling time they cannot (due to the uniqueness of the ODE solution). As such we see on the bottom plot that some (marginal) paths are quite curved and would therefore require a greater number of discretisation steps from the ODE solver during inference.
We can also see how this leads to a significant variance of the CFM loss estimate for $t \approx 0.5$ in the figure below. More generally, samples from the reference distribution which are arbitrarily close to each other can be associated with either target mode, leading to high variance in the vector field regression loss.
An intuitive solution would be to associate data samples with reference samples which are close instead of some arbitrary pairing. We’ll detail this idea next via the concept of couplings and optimal transport.
So far we have constructed the vector field $u_t$ by conditioning and marginalising over data points $x_1$. This is referred to as a one-sided conditioning, where the probability path is constructed by marginalising over $z=x_1$:
\[p_t(x_t) = \int p_t(x_t \mid z) q(z) \dd{z} = \int p_t(x_t \mid x_1) q(x_1) \dd{x_1}\]e.g. \(p(x_t \mid x_1) = \mathcal{N}(x_t \mid t x_1, (1-t)^2)\).
Yet, more generally, we can consider conditioning and marginalising over latent variables $z$, and minimising the following loss:
\[\mathcal{L}_{\mathrm{CFM}}(\theta) = \mathbb{E}_{(t,z,x_t) \sim \mathcal{U}[0,1] q(z) p(\cdot \mid z)}[\| u_\theta(t, x_t) - u_t(x_t \mid z)\|^2].\]As suggested in Liu et al. (2023), Tong et al. (2023), Albergo & Vanden-Eijnden (2022) and Pooladian et al. (2023) one can condition on both endpoints $z=(x_1, x_0)$ of the process, referred as two-sided conditioning. The marginal probability path is defined as:
\[p_t(x_t) = \int p_t(x_t \mid z) q(z) \dd{z} = \int p_t(x_t \mid x_1, x_0) q(x_1, x_0) \dd{x_1} \dd{x_0}.\]For the marginal path to have the proper boundary conditions $p_0 = q_0$ and $p_1 = q_1$, the conditional path $p_t(x_t \mid x_1, x_0)$ must satisfy $p_0(\cdot \mid x_1, x_0)=\delta_{x_0}$ and $p_1(\cdot \mid x_1, x_0) = \delta_{x_1}$.
For instance, a deterministic linear interpolation gives $p(x_t \mid x_0, x_1) = \delta_{(1-t) x_0 + t x_1}(x_t)$, and the simplest choice regarding the coupling $z = (x_1, x_0)$ is to consider independent samples: $q(x_1, x_0) = q_1(x_1) q_0(x_0)$.
One main advantage is that this allows for a non-Gaussian reference distribution $q_0$. Choosing a standard normal noise distribution $q(x_0) = \mathcal{N}(0, \mathrm{I})$, we recover the same one-sided conditional probability path as earlier:
\[p(x_t \mid x_1) = \int p(x_t \mid x_0, x_1) q(x_0) \dd{x_0} = \mathcal{N}(x_t \mid tx_1, (1-t)^2).\]Now let’s go back to the idea of not using an independent coupling (i.e. pairing), but instead correlating pairs $(x_1, x_0)$ through a joint $q(x_1, x_0) \neq q_1(x_1) q_0(x_0)$. Tong et al. (2023) and Pooladian et al. (2023) suggest using the optimal transport coupling
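As a minimal sketch (with hypothetical helper names), the two-sided construction with straight-line interpolants and an independent coupling boils down to: sample $x_0 \sim q_0$ and $x_1 \sim q_1$ independently, form $x_t = (1-t)x_0 + t x_1$, and regress onto the conditional vector field $u_t(x_t \mid x_0, x_1) = x_1 - x_0$:

```python
import numpy as np

rng = np.random.default_rng(0)

def cfm_training_targets(x0, x1, t):
    """Straight-line interpolant x_t = (1 - t) x0 + t x1 and its
    conditional vector field u_t(x_t | x0, x1) = x1 - x0 (the time
    derivative of the interpolant, which is independent of t)."""
    xt = (1.0 - t)[:, None] * x0 + t[:, None] * x1
    ut = x1 - x0
    return xt, ut

# independent coupling: pair reference and data samples at random
x0 = rng.standard_normal((4, 2))         # reference samples ~ q0
x1 = rng.standard_normal((4, 2)) + 5.0   # toy "data" samples ~ q1
t = rng.uniform(size=4)
xt, ut = cfm_training_targets(x0, x1, t)
```

A training step would then regress $u_\theta(t, x_t)$ onto `ut` with a squared loss; the names above are illustrative only.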
\[\begin{equation} \tag{OT} \label{eq:ot} q(x_1, x_0) = \pi(x_1, x_0) \in \arg\inf_{\pi \in \Pi} \int \|x_1 - x_0\|_2^2 \mathrm{d} \pi(x_1, x_0) \end{equation}\]which minimises the optimal transport (i.e. Wasserstein) cost (Monge, 1781, Peyré and Cuturi 2020). The OT coupling $\pi$ associates samples $x_0$ and $x_1$ such that the total distance is minimised.
This OT coupling is illustrated in the right-hand side of the figure below, adapted from Tong et al. (2023). In contrast to the middle figure, which uses an independent coupling, the OT coupling does not have paths that cross. This leads to lower training variance and faster sampling^{10}.
In practice, we cannot compute the optimal coupling $\pi$ between $x_1 \sim q_1$ and $x_0 \sim q_0$, as algorithms solving this problem are only known for distributions supported on finitely many points. In fact, finding a map from $q_0$ to $q_1$ is the generative modelling problem that we are trying to solve in the first place!
Tong et al. (2023) and Pooladian et al. (2023) propose to approximate the OT coupling $\pi$ by computing the optimal coupling only over each mini-batch of data and noise samples, coined mini-batch OT (Fatras et al., 2020). This is scalable, as for a finite collection of samples the OT problem can be computed with quadratic complexity via the Sinkhorn algorithm (Peyré and Cuturi, 2020). This results in a joint distribution $\gamma(i, j)$ over “inputs” \(\big(x_0^{(i)}\big)_{i=1,\dots,B}\) and “outputs” \(\big(x_1^{(j)}\big)_{j=1,\dots,B}\) such that the expected distance is (approximately) minimised. Finally, to construct a mini-batch from this $\gamma$ which we can subsequently use for training, we can either compute the expectation wrt. $\gamma(i, j)$ by considering all $n^2$ pairs (in practice, this often boils down to only needing to consider $n$ disjoint pairs^{11}) or sample a new collection of training pairs $(x_0^{(i’)}, x_1^{(j’)})$ with $(i’, j’) \sim \gamma$^{12}.
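A toy sketch of the mini-batch OT coupling (hypothetical function name). For clarity it brute-forces the assignment over all permutations; in practice one would use a linear-assignment solver such as `scipy.optimize.linear_sum_assignment`, or the Sinkhorn algorithm:

```python
import itertools
import numpy as np

def minibatch_ot_pairs(x0, x1):
    """Mini-batch OT coupling as a linear assignment problem: find the
    permutation sigma minimising sum_i ||x1[sigma(i)] - x0[i]||^2.
    Brute-force O(B!) search, for illustration only."""
    B = len(x0)
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(-1)  # (B, B) pairwise costs
    best = min(itertools.permutations(range(B)),
               key=lambda s: sum(cost[i, s[i]] for i in range(B)))
    return np.array(best)  # i-th reference sample is paired with x1[best[i]]

x0 = np.array([[0.0], [10.0]])
x1 = np.array([[9.0], [1.0]])
pairs = minibatch_ot_pairs(x0, x1)  # pairs x0[0] with x1[1], x0[1] with x1[0]
```

The returned pairs can then be fed to the conditional flow matching loss in place of the independent coupling.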
For example, we can apply this to the \eqref{eq:g2g} example from before, which almost completely removes the crossing paths behaviour described earlier, as can be seen in the figure below.
We also observe similar behaviour when applying this to the more complex example \eqref{eq:mog2mog}, as can be seen in the figure below.
All in all, making use of mini-batch OT seems to be a strict improvement over the uniform sampling approach to constructing the mini-batch in the above examples and has been shown to improve practical performance in a wide range of applications (Tong et al., 2023; Klein et al., 2023).
It’s worth noting that in \eqref{eq:ot} we only considered choosing the coupling $\gamma(i, j)$ such that we minimise the expected squared Euclidean distance. This works well in the examples \eqref{eq:g2g} and \eqref{eq:mog2mog}, but we could also replace the squared Euclidean distance with some other distance metric when constructing the coupling $\gamma(i, j)$. For example, if we were modelling molecules using CNFs, it might also make sense to pick $(i, j)$ such that $x_0^{(i)}$ and $x_1^{(j)}$ are rotationally aligned, as is done in the work of Klein et al. (2023).
In short, we’ve shown that flow matching is an efficient approach to training continuous normalising flows (CNFs), by directly regressing over the vector field instead of explicitly training by maximum likelihood. This is enabled by constructing the target vector field as the marginalisation of simple conditional vector fields which (marginally) interpolate between the reference and data distribution, but crucially for which we can evaluate and integrate over time. A neural network parameterising the vector field can then be trained by regressing over these conditional vector fields. Similarly to CNFs, samples can be obtained at inference time by solving the ODE associated with the neural vector field.
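To make the last step concrete, here is a minimal fixed-step Euler sampler (assumed names; in practice `u` would be the trained network $u_\theta$, and an adaptive ODE solver would typically be used). As a sanity check we integrate the known conditional field $u_t(x \mid x_1) = (x_1 - x)/(1 - t)$, which transports any starting point to $x_1$ at $t=1$:

```python
import numpy as np

def sample_with_euler(u, x0, n_steps=100):
    """Integrate dx/dt = u(t, x) from t = 0 to t = 1 with a fixed-step
    Euler scheme, starting from a reference sample x0."""
    x, dt = x0.copy(), 1.0 / n_steps
    for k in range(n_steps):
        x = x + dt * u(k * dt, x)
    return x

# sanity check with the conditional vector field pointing at a fixed x1
x1 = np.array([3.0, -2.0])
cond_field = lambda t, x: (x1 - x) / (1.0 - t)
x_end = sample_with_euler(cond_field, np.zeros(2))
```

For this linear field the Euler recursion reproduces the straight-line interpolant exactly, so `x_end` recovers `x1` up to floating-point error.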
In this post we have not talked about diffusion (i.e. score based) models on purpose as they are not necessary for understanding flow matching. Yet these are deeply related and even exactly the same in some settings. We are planning to explore these connections, along with generalisations in a follow-up post!
Please cite us as:
@misc{mathieu2024flow,
title = "An Introduction to Flow Matching",
author = "Fjelde, Tor and Mathieu, Emile and Dutordoir, Vincent",
journal = "https://mlg.eng.cam.ac.uk/blog/",
year = "2024",
month = "January",
url = "https://mlg.eng.cam.ac.uk/blog/2024/01/20/flow-matching.html"
}
We deeply thank Michael Albergo, Valentin Debortoli and James Thornton for giving insightful feedback! And thank you to Andrew Foong for pointing out several typos and rendering issues!
Albergo, Michael S. & Vanden-Eijnden, Eric (2023) Building Normalizing Flows with Stochastic Interpolants.
Behrmann, Jens and Grathwohl, Will and Chen, Ricky T. Q. and Duvenaud, David and Jacobsen, Joern-Henrik (2019). Invertible Residual Networks.
Betker, James, Gabriel Goh, Li Jing, Tim Brooks, Jianfeng Wang, Linjie Li, Long Ouyang, Juntang Zhuang, Joyce Lee, Yufei Guo, Wesam Manassra, Prafulla Dhariwal, Casey Chu, Yunxin Jiao and Aditya Ramesh (2023). Improving Image Generation with Better Captions.
Chen & Gopinath (2000). Gaussianization.
Chen & Lipman (2023). Riemannian Flow Matching on General Geometries.
Chen, Ricky T. Q. and Behrmann, Jens and Duvenaud, David K and Jacobsen, Joern-Henrik (2019). Residual flows for invertible generative modeling.
De Bortoli, Mathieu & Hutchinson et al. (2022). Riemannian Score-Based Generative Modelling.
Dupont, Doucet & Teh (2019). Augmented Neural Odes.
Friedman (1987). Exploratory projection pursuit.
Papamakarios, George and Pavlakou, Theo and Murray, Iain (2018). Masked Autoregressive Flow for Density Estimation.
Huang, Chin-Wei and Krueger, David and Lacoste, Alexandre and Courville, Aaron (2018). Neural Autoregressive Flows.
Klein, Krämer & Noé (2023). Equivariant Flow Matching.
Lipman, Yaron and Chen, Ricky T. Q. and Ben-Hamu, Heli and Nickel, Maximilian and Le, Matt (2022). Flow Matching for Generative Modeling.
Liu, Xingchao and Gong, Chengyue and Liu, Qiang (2022). Flow Straight and Fast: Learning to Generate and Transfer Data with Rectified Flow.
Monge, Gaspard (1781). Mémoire Sur La Théorie Des Déblais et Des Remblais.
Peyré, Gabriel and Cuturi, Marco (2020). Computational Optimal Transport.
Pooladian, Aram-Alexandre and {Ben-Hamu}, Heli and {Domingo-Enrich}, Carles and Amos, Brandon and Lipman, Yaron and Chen, Ricky T. Q. (2023). Multisample Flow Matching: Straightening Flows With Minibatch Couplings.
Song, Sohl-Dickstein & Kingma et al. (2020). Score-Based Generative Modeling Through Stochastic Differential Equations.
Tong, Alexander and Malkin, Nikolay and Fatras, Kilian and Atanackovic, Lazar and Zhang, Yanlei and Huguet, Guillaume and Wolf, Guy and Bengio, Yoshua (2023). Simulation-Free Schrodinger Bridges via Score and Flow Matching.
Tong, Malkin & Huguet et al. (2023). Improving and Generalizing Flow-Based Generative Models With Minibatch Optimal Transport.
Watson, Joseph L. and Juergens, David and Bennett, Nathaniel R. and Trippe, Brian L. and Yim, Jason and Eisenach, Helen E. and Ahern, Woody and Borst, Andrew J. and Ragotte, Robert J. and Milles, Lukas F. and Wicky, Basile I. M. and Hanikel, Nikita and Pellock, Samuel J. and Courbet, Alexis and Sheffler, William and Wang, Jue and Venkatesh, Preetham and Sappington, Isaac and Torres, Susana Vázquez and Lauko, Anna and De Bortoli, Valentin and Mathieu, Emile and Ovchinnikov, Sergey and Barzilay, Regina and Jaakkola, Tommi S. and DiMaio, Frank and Baek, Minkyung and Baker, David (2023). De Novo Design of Protein Structure and Function with RFdiffusion.
The property $\phi \circ \phi^{-1} = \Id$ implies, by the chain rule, \(\begin{equation*} \pdv{\phi}{x} \bigg|_{x = \phi^{-1}(y)} \pdv{\phi^{-1}}{y} \bigg|_{y} = \Id \iff \pdv{\phi}{x} \bigg|_{x = \phi^{-1}(y)} = \bigg( \pdv{\phi^{-1}}{y} \bigg|_{y} \bigg)^{-1} \quad \forall y \in \mathbb{R}^d \end{equation*}\) ↩
Autoregressive (Papamakarios et al., 2018; Huang et al., 2018) One strategy is to factor the flow’s Jacobian to have a triangular structure by factorising the density as $p_\theta(x) = \prod_{d} p_\theta(x_d;x_{d<})$ with each conditional $p_\theta(x_d;x_{d<})$ being induced via a flow. Low rank residual (Van Den Berg et al., 2018) Another approach is to construct a flow via a residual connection: \(\begin{equation*} \phi(x) = x + A h(B x + b) \end{equation*}\) with parameters $A \in \R^{d\times m}$, $B \in \R^{m\times d}$ and $b \in \R^m$. Leveraging Sylvester’s determinant identity $\det(I_d + AB)=\det(I_m + BA)$, the determinant computation can be reduced to one of an $m \times m$ matrix, which is advantageous if $m \ll d$. ↩
A sufficient condition for $\phi_k$ to be invertible is for $u_k$ to be $1/h$-Lipschitz [Behrmann et al., 2019]. The inverse $\phi_k^{-1}$ can be approximated via fixed-point iteration (Chen et al., 2019). ↩
A sufficient condition for $\phi_t$ to be invertible is for $u_t$ to be Lipschitz continuous, by the Picard–Lindelöf theorem. ↩
The Fokker–Planck equation gives the time evolution of the density induced by a stochastic process. For ODEs where the diffusion term is zero, one recovers the transport equation. ↩
Expanding the divergence in the transport equation we have: \(\begin{equation*} \frac{\partial}{\partial_t} p_t(x_t) = - (\nabla \cdot (u_t p_t))(x_t) = - p_t(x_t) (\nabla \cdot u_t)(x_t) - \langle \nabla_{x_t} p_t(x_t), u_t(x_t) \rangle. \end{equation*}\) Yet since $x_t$ also depends on $t$, to get the total derivative we have \(\begin{align} \frac{\dd}{\dd t} p_t(x_t) &= \frac{\partial}{\partial_t} p_t(x_t) + \langle \nabla_{x_t} p_t(x_t), \frac{\dd}{\dd t} x_t \rangle \\ &= - p_t(x_t) (\nabla \cdot u_t)(x_t) - \langle \nabla_{x_t} p_t(x_t), u_t(x_t) \rangle + \langle \nabla_{x_t} p_t(x_t), \frac{\dd}{\dd t} x_t \rangle \\ &= - p_t(x_t) (\nabla \cdot u_t)(x_t). \end{align}\) Where the last step comes from $\frac{\dd}{\dd t} x_t = u_t$. Hence, $\frac{\dd}{\dd t} \log p_t(x_t) = \frac{1}{p_t(x_t)} \frac{\dd}{\dd t} p_t(x_t) = - (\nabla \cdot u_t)(x_t).$ ↩
The Skilling-Hutchinson trace estimator is given by $\Tr(A) = \E[v^\top A v]$ with $v \sim p$ isotropic and centred. In our setting we are interested in $\div(u_t)(x) = \Tr(\frac{\partial u_t(x)}{\partial x}) = \E[v^\top \frac{\partial u_t(x)}{\partial x} v]$ which can be approximated with a Monte-Carlo estimator, where the integrand is computed via automatic forward or backward differentiation. ↩
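A toy numerical check of the estimator (hypothetical names; here `matvec` is a plain matrix product standing in for the Jacobian-vector product one would obtain via automatic differentiation):

```python
import numpy as np

rng = np.random.default_rng(0)

def hutchinson_trace(matvec, d, n_samples=10_000):
    """Skilling-Hutchinson estimator: Tr(A) ~ mean of v^T A v with
    isotropic, centred probe vectors v (Rademacher here)."""
    est = 0.0
    for _ in range(n_samples):
        v = rng.choice([-1.0, 1.0], size=d)
        est += v @ matvec(v)
    return est / n_samples

A = np.array([[2.0, 1.0],
              [0.0, 3.0]])
tr_est = hutchinson_trace(lambda v: A @ v, d=2)
```

In the CNF setting, `matvec` would compute $v \mapsto \frac{\partial u_t(x)}{\partial x} v$ with forward- or reverse-mode autodiff, avoiding the $d \times d$ Jacobian.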
The top row is with reference $p_0 = \mathcal{N}([-a, 0], I)$ and target $p_1 = (1/2) \mathcal{N}([a, -10], I) + (1 / 2) \mathcal{N}([a, 10], I)$, and the bottom row is the \ref{eq:g2g} example. The left column shows the straight-line solutions for the marginals and the right column shows the marginal solutions induced by considering the straight-line conditional interpolants. ↩
Developing the square in both losses we get: \(\|u_\theta(t, x) - u_t(x \mid x_1)\|^2 = \|u_\theta(t, x)\|^2 + \|u_t(x \mid x_1)\|^2 - 2 \langle u_\theta(t, x), u_t(x \mid x_1) \rangle,\) and \(\|u_\theta(t, x) - u_t(x)\|^2 = \|u_\theta(t, x)\|^2 + \|u_t(x)\|^2 - 2 \langle u_\theta(t, x), u_t(x) \rangle.\) Taking the expectation over the last inner product term: \(\begin{align} \mathbb{E}_{x \sim p_t} ~\langle u_\theta(t, x), u_t(x) \rangle &= \int \langle u_\theta(t, x), \int u_t(x|x_1) \frac{p_t(x \mid x_1)q(x_1)}{p_t(x)} dx_1 \rangle p_t(x) \dd{x} \\ &= \int \langle u_\theta(t, x), \int u_t(x \mid x_1) p_t(x \mid x_1)q(x_1) dx_1 \rangle \dd{x} \\ &= \int \int \langle u_\theta(t, x), u_t(x \mid x_1) \rangle p_t(x \mid x_1)q(x_1) dx_1 \dd{x} \\ &= \mathbb{E}_{q_1(x_1) p(x \mid x_1)} ~\langle u_\theta(t, x), u_t(x \mid x_1) \rangle. \end{align}\) Then we see that the neural network squared norm terms are equal since: \(\mathbb{E}_{p_t} \|u_\theta(t, x)\|^2 = \int \|u_\theta(t, x)\|^2 p_t(x \mid x_1) q(x_1) \dd{x} \dd{x_1} = \mathbb{E}_{q_1(x_1) p(x \mid x_1)} \|u_\theta(t, x)\|^2\) ↩
Dynamic optimal transport [Benamou and Brenier, 2000] \(W(q_0, q_1)_2^2 = \inf_{p_t, u_t} \int \int_0^1 \|u_t(x)\|^2 p_t(x) \dd{t} \dd{x}\) ↩
In mini-batch OT, we only work with the empirical distributions over $x_0^{(i)}$ and $x_1^{(j)}$, i.e. they all have weights $1 / n$, where $n$ is the size of the mini-batch. This means that we can find a $\gamma$ matching the $\inf$ in \eqref{eq:ot} by solving what’s referred to as a linear assignment problem. This results in a sparse matrix with exactly $n$ entries, each with a weight of $1 / n$. In such a scenario, computing the expectation over the joint $\gamma(i, j)$, which has $n^2$ entries but in this case only $n$ non-zero entries, can be done by considering only $n$ training pairs, where every $i$ is involved in exactly one pair and similarly for every $j$. This is usually what’s done in practice. When solving the assignment problem is too computationally intensive, using Sinkhorn and sampling from the coupling might be the preferable approach. ↩
Note that the size of the mini-batch sampled from $\gamma(i, j)$ does not necessarily have to match the mini-batch size used to construct the mini-batch OT approximation, since we can sample from $\gamma$ with replacement; using the same size is nonetheless typical in practice, e.g. Tong et al. (2023). ↩
We tackle these questions in this second part of the natural-gradient for variational inference series. We show that we can get good performance at large scales with Bayesian principles, while maintaining reasonable uncertainties. We start by focussing on question (i): the issue of scalability. We notice similarities between our NGVI algorithm and Adam, and exploit this to borrow tricks that the community has developed for Adam over many years. This allows us to scale up to very large datasets/architectures. We then turn our focus to question (ii): have we improved on neural networks’ poorly-calibrated uncertainties thanks to our Bayesian thinking? We will see some benefits. Along the way, we will discuss the price we pay for them.
This second part of the blog closely follows a paper I was involved in, Practical Deep Learning with Bayesian Principles (Osawa et al., 2019). There is also a codebase if you are interested in experimenting with our algorithm, VOGN (Variational Online Gauss-Newton). As a postscript to this blog post, we summarise some good practices for training your own neural network with VOGN. $\newcommand{\vparam}{\boldsymbol{\theta}}$ $\newcommand{\veta}{\boldsymbol{\eta}}$ $\newcommand{\vphi}{\boldsymbol{\phi}}$ $\newcommand{\vmu}{\boldsymbol{\mu}}$ $\newcommand{\vSigma}{\boldsymbol{\Sigma}}$ $\newcommand{\vm}{\mathbf{m}}$ $\newcommand{\vF}{\mathbf{F}}$ $\newcommand{\vI}{\mathbf{I}}$ $\newcommand{\vg}{\mathbf{g}}$ $\newcommand{\vH}{\mathbf{H}}$ $\newcommand{\vs}{\mathbf{s}}$ $\newcommand{\myexpect}{\mathbb{E}}$ $\newcommand{\pipe}{\,|\,}$ $\newcommand{\data}{\mathcal{D}}$ $\newcommand{\loss}{\mathcal{L}}$ $\newcommand{\gauss}{\mathcal{N}}$
We start with the equations for the VOGN algorithm, derived in our previous blog post. This also serves as a quick summary of notation: please look at the previous blog post if anything is unclear! (Colours are purely for illustrative purposes.)
\begin{align} \label{eq:VOGN_mu} \vmu_{t+1} &= \vmu_t - \alpha_t \frac{ {\color{purple}\hat{\vg}(\vparam_t)} + {\color{blue}\tilde{\delta}}\vmu_t}{\vs_{t+1} + {\color{blue}\tilde{\delta}}}, \newline \label{eq:VOGN_Sigma} \vs_{t+1} &= (1-\beta_t)\vs_t + \beta_t \frac{1}{M} \sum_{i\in\mathcal{M}_t}\left( {\color{purple}\vg_i(\vparam_t)} \right)^2. \end{align}
Remember, we are updating the parameters of our (mean-field Gaussian) approximate posterior $q_t(\vparam)=\gauss(\vparam; \vmu_t, \vSigma_t)$, where $\vparam$ are the parameters of a neural network. We do this by iteratively updating two vectors, $\vmu_t$ and $\vs_t$, where $t$ indexes the iteration. We have a zero-mean prior $p(\vparam) = \gauss(\vparam; \color{blue}\boldsymbol{0}, {\color{blue}\delta^{-1}} \vI)$, and ${\color{blue}\tilde{\delta}} = {\color{blue}{\delta}} / N$. Our dataset consists of $N$ data examples, and we are taking per-example gradients of the negative log-likelihood $\color{purple}\vg_i(\vparam_t)$ at a sample from our current approximate posterior, $\vparam_t \sim q_t(\vparam)$. For a randomly-sampled minibatch $\mathcal{M}_t$ of size $M$, we have defined the average gradient ${\color{purple}\hat{\vg}(\vparam_t)} = \frac{1}{M}\sum_{i\in\mathcal{M}_t} {\color{purple}\vg_i(\vparam_t)}$. There is a simple relation between $\vSigma_t$ and $\vs_t$, $\vSigma_t^{-1} = N\vs_t + {\color{blue}\delta \vI}$. Finally, $\alpha_t>0$ and $0<\beta_t<1$ are learning rates, and all operations are element-wise.
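The two updates above can be sketched element-wise as follows (a toy NumPy sketch with assumed names; the real implementation works with per-example gradients of a neural network, evaluated at a sample from the current approximate posterior):

```python
import numpy as np

def vogn_step(mu, s, per_example_grads, alpha, beta, delta_tilde):
    """One VOGN iteration, following the two update equations above:
    an exponential moving average of squared per-example gradients (s),
    then a preconditioned mean update. Note: no square root over s,
    unlike Adam. All operations are element-wise."""
    g_hat = per_example_grads.mean(axis=0)                      # minibatch gradient
    s_new = (1 - beta) * s + beta * (per_example_grads ** 2).mean(axis=0)
    mu_new = mu - alpha * (g_hat + delta_tilde * mu) / (s_new + delta_tilde)
    return mu_new, s_new

mu, s = np.zeros(2), np.zeros(2)
grads = np.array([[1.0, 2.0],
                  [3.0, 4.0]])  # per-example gradients g_i, minibatch of size 2
mu, s = vogn_step(mu, s, grads, alpha=1.0, beta=0.5, delta_tilde=1.0)
```

The diagonal posterior covariance is then recovered from `s` via $\vSigma_t^{-1} = N\vs_t + \delta\vI$.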
It turns out that this update equation is very similar to Adam (Kingma & Ba, 2015). To see this, let’s write down the form that commonly-used optimisers take, such as SGD, RMSProp (Hinton, 2012), and Adam:
\begin{align} \label{eq:Adam_mu} \vmu_{t+1} &= \vmu_t - \alpha_t \frac{ {\color{purple}\hat{\vg}(\vmu_t)} + {\color{blue}\delta}\vmu_t} {\sqrt{\vs_{t+1}} + \epsilon}, \newline \label{eq:Adam_Sigma} \vs_{t+1} &= (1-\beta_t)\vs_t + \beta_t \left( \frac{1}{M} \sum_{i\in\mathcal{M}_t} {\color{purple}\vg_i(\vmu_t)} + {\color{blue}\delta} \vmu_t \right) ^2, \end{align}
where $\delta>0$ is our weight-decay regulariser, and $\epsilon>0$ is a small scalar constant. Immediately we can see striking similarities in the overall form of the equations! Let’s take a closer look at the similarities and differences:
The Gauss-Newton approximation (Difference 5) is a better approximation to second-order information (Hessian) than the gradient-magnitude approach. This better approximation is likely why VOGN does not require a square root over $\vs_{t+1}$ in the update for the means (Difference 2). However, calculating the Gauss-Newton approximation requires additional computation in frameworks such as PyTorch, leading to VOGN being slower (per epoch) compared to Adam. This is despite using speed-up tricks (Goodfellow, 2015).
The similarities in the equations indicate that we might be able to take techniques people use to scale Adam up to large datasets and architectures, and apply similar ideas to scale VOGN up. We can use batch normalisation, momentum, clever initialisations, data augmentation, learning rate scheduling, and so on.
Let’s go over a list of each of the changes we make, providing some intuition for them. Please see Osawa et al. (2019) for further details. Using all these techniques, we are able to scale VOGN to datasets like CIFAR-10 and ImageNet, and architectures such as ResNets.
Batch normalisation (Ioffe & Szegedy, 2015) empirically speeds up and stabilises training of neural networks. We can use BatchNorm layers as is normal in deep learning. In fact, in our VOGN implementation, we found that we do not have to maintain uncertainty over BatchNorm parameters.
We can also use momentum for VOGN in a similar way to Adam: we introduce momentum on $\vmu_t$, along with a momentum rate.
Over many years of training neural networks with SGD and Adam, the community has found tricks to speed up training using clever initialisation. We can get these same benefits by changing VOGN to look more like Adam at initialisation, before slowly relaxing our algorithm to become the full VOGN algorithm later in training.
This is achieved by introducing a tempering parameter $\tau$ in front of the KL term in the ELBO, which propagates its way through to the VOGN equations. To see exactly where $\tau$ crops up, please look at Equation 4 from Osawa et al. (2019), or see Algorithm 1 below. As $\tau\rightarrow 0$, we (loosely speaking) get more similar to Adam. So, at the beginning of training, we initialise $\tau$ at something small (like $0.1$) and increase to $1$ during the first few optimisation steps.
Other initialisations are the same as for Adam: $\vmu_t$ is initialised using init.xavier_normal from PyTorch (Glorot & Bengio, 2010), and the momentum term is initialised to zero. VOGN’s $\vs_t$ is initialised using an additional forward pass through the first minibatch of data.
We can use learning rate scheduling for $\alpha_t$ exactly like is used for Adam and SGD at a large scale. We regularly decay the learning rate by a factor (typically a factor of 10).
When training on image datasets, data augmentation (DA) can improve performance drastically. For example, we can use random cropping and/or random horizontal flipping of images.
Unfortunately, directly applying DA to VOGN does not lead to improvements, and also negatively affects uncertainty calibration. But we note that DA can be viewed as increasing the effective size of our dataset: remember that our dataset size $N$ affects VOGN (as opposed to Adam and SGD, where $N$ does not appear, as it is unidentifiable with the weight-decay factor). So, we view DA as increasing $N$ by some factor depending on the exact DA technique: for example, if we horizontally flip each image with a probability of 50%, we increase $N$ by a factor of 2.
This is still a heuristic, and not mathematically rigorous. It seems to work quite well in our experiments, but requires further theoretical investigation. It is also closely related to KL-annealing in variational inference, as well as the recently-termed ‘cold posterior effect’ (Wenzel et al., 2020; Loo et al., 2021; Aitchison, 2021).
We would like to use multiple GPUs in parallel to perform large experiments quickly. Typically, we would just parallelise data, splitting up large minibatch sizes by sending different data to different GPUs. With VOGN, we can also parallelise computation over Monte-Carlo samples $\vparam_t \sim q_t(\vparam)$. Every GPU can use a different sample $\vparam_t$. This reduces variance during training, and we empirically find it leads to quicker convergence.
We introduce an external damping factor $\gamma$, added to $\vs_{t+1}$ in the denominator of Equation \eqref{eq:VOGN_mu} (Zhang et al., 2018). This increases the lower bound of the eigenvalues on the diagonal covariance $\vSigma_t$, preventing the step size and noise from becoming too large. However, this also detracts from the principled variational inference equations, and there is currently no theoretical justification for this (beyond the intuition we just provided).
Let’s recap. We derived the VOGN equations (Equations \eqref{eq:VOGN_mu} and \eqref{eq:VOGN_Sigma}) in the previous blog post. We started this post by comparing the equations to Adam, noting key similarities and differences. One key difference was the Gauss-Newton approximation, which slows VOGN down (per epoch) compared to SGD-based algorithms like Adam. Based on the similarities, we borrowed tricks used to scale Adam to large data settings, and applied them to VOGN.
All of these tricks are important to get VOGN’s results on ImageNet. The final algorithm is summarised in Algorithm 1 below. One downside of VOGN when compared to Adam is the additional hyperparameters that require tuning. At the end of this blog post, we provide best practices for tuning these hyperparameters.
We are finally in a place to run VOGN on ImageNet and analyse results. We take Algorithm 1 and run it on ImageNet.
Let’s go through these results slowly.
There are many more results in Osawa et al. (2019), such as CIFAR-10 with a variety of architectures. We tend to see a similar story, where VOGN performs comparably on validation accuracy, and well on uncertainty metrics.
Due to the Bayesian nature of VOGN, we see some interesting trade-offs (see the paper for figures).
Reducing the prior precision $\delta$ results in higher validation accuracy, but also a larger train-test gap, corresponding to more overfitting. With very small prior precisions, performance is similar to non-Bayesian methods like Adam.
Increasing the number of training MC samples ($K$ in Algorithm 1) improves VOGN’s convergence rate and stability, as it reduces gradient variance during training. But this is at the cost of increased computation. Increasing the number of MC samples during testing improves generalisation.
If you are like me, metrics such as negative-log-likelihood and expected calibration error do not mean much when it comes to analysing if your algorithm has ‘better uncertainty’. We should also test on downstream tasks to see how reliable our methods are, and more and more papers are starting to do so (see also this year’s NeurIPS Bayesian Deep Learning workshop, which makes this a priority). The VOGN paper tested on two downstream tasks: continual learning and out-of-distribution performance.
Continual Learning: I personally think continual learning is a very good way to test approximate Bayesian inference algorithms, particularly variational deep-learning algorithms. We tested VOGN on Permuted MNIST, finding it performs as well as VCL (Nguyen et al., 2018; Swaroop et al., 2019), but trained more than an order of magnitude quicker. More recently, VOGN has also achieved good results on a bigger Split CIFAR benchmark (see Section 4.5 of Eschenhagen (2019)), which VCL struggles to scale to.
Out-of-distribution performance: We also tested VOGN on standard out-of-distribution benchmarks, such as training on CIFAR-10 and testing on SVHN and LSUN. Figure 5 in the paper shows results (histograms of predictive entropy), where we qualitatively see VOGN performing well.
In the first blog post, we derived VOGN, our natural-gradient variational inference algorithm. In this blog post, we scaled it all the way to ImageNet. We made approximations along the way, but by being clever about when and where to make approximations, we have ended up with a practical algorithm that has Bayesian principles. Our final algorithm therefore performs reasonably well in downstream tasks.
It has been two years since publishing VOGN’s performance on ImageNet, and the field continues to move at break-neck pace. More algorithms and more benchmarks continue to be published, as well as more insight into VI.
Firstly, many thanks to my co-authors Kazuki Osawa, Anirudh Jain, Runa Eschenhagen, Richard Turner, Rio Yokota and Emtiyaz Khan, many of whom also provided valuable feedback on these blog posts. I would also like to thank Andrew Foong, Wessel Bruinsma and Stratis Markou for their comments during drafting of these blog posts.
As we saw in Algorithm 1, there are many hyperparameters that need tuning for VOGN (and generally for VI at a large-scale). Here we briefly summarise how we did this in Osawa et al. (2019), following the guidelines presented in Section 3 of the paper. The key idea is to make sure VOGN training closely follows Adam’s trajectory in the beginning of training.
First, we tune hyperparameters for OGN, which is the same as VOGN except setting $\vparam_t=\vmu_t$ (no MC sampling). OGN is more stable than VOGN and is a convenient stepping stone as we move from Adam to VOGN. So, we initialise OGN’s hyperparameters at Adam’s values, and tune until OGN is competitive with Adam. This requires tuning learning rates, prior precision $\delta$, and setting a suitable value for the data augmentation factor (if using data augmentation).
Then, we move to VOGN. Now, we (fine-)tune the prior precision $\delta$, warm-start the tempering parameter $\tau$ (for example, increasing $\tau$ from $0.1\rightarrow1$ during the first few optimisation steps), and set the number of MC samples $K$ for VOGN (more samples means more stable training, but higher computation cost). We also now consider adding an external damping factor $\gamma$ if required.
Despite their successes, accurately quantifying uncertainty in the predictions of DNNs is notoriously hard, especially if there is a shift in the data distribution between train and test time. In practice, this might often lead to overconfident predictions, which is particularly harmful in safety-critical applications such as healthcare and autonomous driving. One principled approach to quantify the predictive uncertainty of a neural net is to use Bayesian inference.
The standard practice in deep learning is to estimate the parameters using just a single point found through gradient-based optimisation. In contrast, in Bayesian deep learning (check out this blog post for an introduction to Bayesian deep learning), the goal is to infer a full posterior distribution over the model’s weights. By capturing a distribution over weights, we capture a distribution over neural networks, which means that prediction essentially takes into account the predictions of (infinitely) many neural networks. Intuitively, on data points that are very distinct from the training data, these different neural nets will disagree on their predictions. This will result in high predictive uncertainty on such data points and therefore reduce overconfidence.
The problem is that modern deep neural nets are so big that even trying to approximate this posterior is highly non-trivial. We are not even talking about humongous 100-billion-parameter models like OpenAI’s GPT-3 here (Brown et al., 2020) — even for a neural net with more than just a few layers it’s hard to do good posterior inference! Therefore, it’s becoming more and more challenging to design approximate inference methods that actually scale.
To cope with this problem, many existing Bayesian deep learning methods make very strong and unrealistic approximations to the structure of the posterior. For example, the common mean field approximation approximates the posterior by a distribution which fully factorises over individual weights. Unfortunately, recent papers (Ovadia et al., 2019; Foong et al., 2019) have empirically demonstrated that such strong assumptions result in bad performance on downstream tasks such as uncertainty estimation. Can we do better than this? $\newcommand{\vy}{\mathbf{y}}$ $\newcommand{\vw}{\mathbf{w}}$ $\newcommand{\mH}{\mathbf{H}}$ $\newcommand{\mX}{\mathbf{X}}$ $\newcommand{\D}{\mathcal{D}}$ $\newcommand{\c}{\textsf{c}}$
Most existing Bayesian deep learning methods try to do inference over all the weights of the neural net. But do we actually need to estimate a posterior distribution over all weights?
It turns out that you often don’t need all those weights. In particular, recent research (Cheng et al., 2017) has shown that, since deep neural nets are so heavily overparametrised, it’s possible to find a small subnetwork within a neural net containing only a very small fraction of the weights, which, miraculously, can achieve the same accuracy as the full neural net. These subnetworks can be found by so-called pruning techniques.
As shown in Figure 1, pruning techniques typically first train the neural net, and then, after training, remove certain weights or even entire neurons according to some criterion. There has been a lot of recent interest in this research direction; for example, the best paper award at ICLR 2019 went to Jonathan Frankle and Michael Carbin’s now famous work on the lottery ticket hypothesis (Frankle and Carbin, 2018), which showed that you can even retrain the pruned network from scratch and still achieve the same accuracy as the full network.
But how does this help us? We asked ourselves the exact same question about the model uncertainty: Can a full deep neural net’s model uncertainty be well-preserved by a small subnetwork’s model uncertainty? It turns out that the answer is yes, and in the remainder of this blog post, you will learn about how we came to this conclusion.
Assume that we have divided the weights $\vw$ into two disjoint subsets: (1) the subnetwork $\vw_S$ and (2) the set of all remaining weights $\{\vw_r\}_{r \in S^\c}$. We will later describe how we select the subnetwork; for now, just assume that we have it already. We propose to approximate the posterior distribution over the weights as follows: \begin{equation} p(\vw \cond \D) \overset{\text{(i)}}{\approx} p(\vw_S \cond \D) \prod_{r \in S^\c} \delta(\vw_r - \widehat{\vw}_r) \overset{\text{(ii)}}{\approx} q(\vw_S) \prod_{r \in S^\c} \delta(\vw_r - \widehat{\vw}_r) =: q_S(\vw). \end{equation} The first step (i) of our posterior approximation then involves a posterior distribution over just the subnetwork $\vw_S$, and delta functions over all remaining weights $\{\vw_r\}_{r \in S^\c}$. Put differently, we only treat the subnetwork $\vw_S$ in a probabilistic way, and assume that each remaining weight $\vw_r$ is deterministic and set to some fixed value $\widehat\vw_r$. Unfortunately, exact inference over the subnetwork is still intractable, so, in the second step (ii) of our approximation, we introduce an approximate posterior $q$ over the subnetwork $\vw_S$. Importantly, as the subnetwork is much smaller than the full network, this allows us to use expressive posterior approximations that would otherwise be computationally intractable (e.g. full-covariance Gaussians). That’s it.
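In code, drawing from $q_S(\vw)$ amounts to sampling only the subnetwork entries and leaving all other weights fixed at their point estimates (a sketch with assumed names; the Gaussian over the subnetwork stands in for the approximate posterior $q$):

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_subnetwork_posterior(w_hat, idx_S, mu_S, cov_S):
    """Draw one weight vector from q_S(w): the subnetwork entries idx_S
    come from a full-covariance Gaussian, while every other entry stays
    fixed at its point estimate w_hat (the delta functions)."""
    w = w_hat.copy()
    w[idx_S] = rng.multivariate_normal(mu_S, cov_S)
    return w

w_hat = np.array([1.0, 2.0, 3.0, 4.0, 5.0])   # point estimates for all weights
idx_S = np.array([1, 3])                       # the chosen subnetwork
w = sample_subnetwork_posterior(w_hat, idx_S,
                                mu_S=w_hat[idx_S],
                                cov_S=0.01 * np.eye(2))
```

Predictions would then average the network's outputs over several such samples.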
There are a few questions that we still need to answer:

Q1: How do we infer the approximate posterior $q(\vw_S)$ over the subnetwork?
Q2: How do we set the remaining weights $\{\vw_r\}_{r \in S^\c}$?
Q3: How do we select the subnetwork $\vw_S$ in the first place?

Let’s start with Q1.
In this work, we infer a full-covariance Gaussian posterior over the subnetwork using the Laplace approximation, which is a classic approximate inference technique. If you don’t recall how the Laplace approximation works, below we provide a short summary. For more details on the Laplace approximation and a review of its use in deep learning, please refer to Daxberger et al. (2021).
The Laplace approximation proceeds in two steps.
Obtain a point estimate over all model weights using maximum a-posteriori (short MAP) inference. In deep learning, this is typically done using stochastic gradient-based optimisation methods such as SGD. \begin{equation} \widehat\vw = \argmax_{\vw} \, [\log p(\vy \cond \mX, \vw) + \log p(\vw)] \end{equation}
Locally approximate the log-density of the posterior with a second-order Taylor expansion around the MAP estimate. This produces a full-covariance Gaussian posterior over the weights, where the mean of the Gaussian is simply the MAP estimate, and the covariance matrix is the inverse of the Hessian $\mH$ of the loss with respect to the weights $\vw$ (averaged over the training data points): \begin{equation} p(\vw \cond \D) \approx q(\vw) = \Normal(\vw \cond \widehat\vw, \mH^{-1}). \end{equation}
What this essentially does is it defines a Gaussian centered at the MAP estimate, with a covariance matrix that matches the curvature of the loss at the MAP estimate, as illustrated in Figure 2.
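To make the two steps concrete, here is a minimal sketch on a toy Bayesian linear regression model, where both the MAP estimate and the Hessian are available in closed form (for a neural network, step 1 would be SGD training and the Hessian would typically be approximated, e.g. by a generalised Gauss-Newton matrix). All sizes and hyperparameters below are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model: Bayesian linear regression, where the Laplace approximation
# is exact, so it doubles as a sanity check.  All sizes and hyperparameters
# (prior precision `alpha`, observation noise `noise_var`) are hypothetical.
n, d = 50, 3
X = rng.normal(size=(n, d))
w_true = np.array([1.0, -2.0, 0.5])
noise_var = 0.1
y = X @ w_true + rng.normal(scale=np.sqrt(noise_var), size=n)
alpha = 1.0  # prior precision: p(w) = N(0, alpha^{-1} I)

# Step 1: MAP estimate (closed form here; for a neural net this is SGD training).
H = X.T @ X / noise_var + alpha * np.eye(d)      # Hessian of the negative log joint
w_map = np.linalg.solve(H, X.T @ y / noise_var)

# Step 2: Laplace posterior q(w) = N(w_map, H^{-1}).
posterior_cov = np.linalg.inv(H)
posterior_std = np.sqrt(np.diag(posterior_cov))  # per-weight uncertainty
```

Because the model is linear-Gaussian, the "approximation" here recovers the exact posterior; for a neural network it is only a local Gaussian fit around the MAP estimate.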
The main advantage of the Laplace approximation, and also the reason why we use it, is that it is applied post-hoc on top of a MAP estimate and doesn’t require us to re-train the network. This is practically very appealing as MAP estimation is something we can do very well in deep neural nets. The main issue, however, is that it requires us to compute, store, and invert the full Hessian $\mH$ over all weights. This scales quadratically in space and cubically in time (in terms of the number of weights) and is therefore computationally intractable for modern neural nets.
Fortunately, in our case, we don’t actually want to do inference over all the weights, but only over a subnetwork. In this case, the second step of the Laplace approximation involves inferring a full-covariance Gaussian posterior over only the subnetwork weights $\vw_S$: \begin{equation} p(\vw_S \cond \D) \approx q(\vw_S) = \Normal(\vw_S \cond \widehat\vw_S, \mH_S^{-1}). \end{equation} This is now tractable, since the subnetwork will in practice be substantially smaller than the full network, effectively giving us quadratic gains in space complexity and cubic gains in time complexity!
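A sketch of this subnetwork version of step 2: given a Hessian and a chosen index set $S$, only the $|S| \times |S|$ block needs to be stored and inverted. The matrix and indices below are hypothetical stand-ins, and in practice one would of course never materialise the full Hessian; the small example only illustrates the indexing and the cost argument:

```python
import numpy as np

rng = np.random.default_rng(1)

# Stand-ins: a small symmetric positive-definite "Hessian" and pretend
# MAP weights.  The index set S is likewise hypothetical.
d = 1000
A = rng.normal(size=(d, d))
H_full = A @ A.T / d + np.eye(d)   # some SPD matrix playing the role of H
w_map = rng.normal(size=d)         # pretend MAP weights

S = np.array([3, 17, 256, 911])    # chosen subnetwork indices

# Only the |S| x |S| block of the Hessian is needed:
H_S = H_full[np.ix_(S, S)]
cov_S = np.linalg.inv(H_S)         # O(|S|^3) instead of O(d^3)
mean_S = w_map[S]                  # q(w_S) = N(mean_S, cov_S); rest stays at MAP
```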
In fact, this also answers Q2 of how to set the remaining weights not part of the subnetwork: Since the Laplace approximation requires us to first obtain a MAP estimate over all weights, it’s natural to simply leave all other weights at their MAP estimates!
Let’s now look at how subnetwork inference is done in practice.
Overall, our proposed subnetwork inference algorithm comprises the following four steps:

1. Train the neural network to obtain a MAP estimate $\widehat\vw$ of all weights.
2. Select a small subnetwork $\vw_S$ (described below).
3. Infer a full-covariance Gaussian posterior over the subnetwork $\vw_S$ via the Laplace approximation, leaving all remaining weights fixed at their MAP estimates.
4. Make predictions using the resulting posterior approximation $q_S(\vw)$.
Ok, now we know how to do inference over the subnetwork, but how do we find the subnetwork in the first place?
Recall that we want to preserve as much model uncertainty as possible with our subnetwork. A natural goal is therefore to find the subnetwork whose posterior is closest to the full network posterior. That is, we want to find the subset of weights that minimises some measure of discrepancy between the posterior over the full network and the posterior over the subnetwork.
To measure this discrepancy, we choose to use the Wasserstein distance: \begin{align} &\min \text{Wass}[\ \text{exact full posterior}\ |\ \text{subnetwork posterior}\ ] \nonumber \vphantom{\prod} \newline &\qquad= \min \text{Wass}[\ p(\mathbf{w} \cond \mathcal{D})\ |\ q_S(\mathbf{w})\ ] \vphantom{\prod} \newline &\qquad\approx \min \text{Wass}[\ \mathcal{N}\left(\mathbf{w}; \widehat{\mathbf{w}}, \mathbf{H}^{-1}\right)\ |\ \mathcal{N}(\mathbf{w}_S; \widehat{\mathbf{w}}_S, \mathbf{H}_S^{-1}) \prod_{r \in S^\c} \delta(\mathbf{w}_r - \widehat{\mathbf{w}}_r )\ ]. \end{align} As the exact full network posterior $p(\mathbf{w} \cond \mathcal{D})$ is intractable, we here approximate it as a Gaussian $\mathcal{N}\left(\mathbf{w}; \widehat{\mathbf{w}}, \mathbf{H}^{-1}\right)$ over all weights (also estimated via the Laplace approximation). Also, as described earlier, the subnetwork posterior $q_S(\mathbf{w})$ is composed of a Gaussian $\mathcal{N}(\mathbf{w}_S; \widehat{\mathbf{w}}_S, \mathbf{H}_S^{-1})$ over the subnetwork and delta functions $\delta(\mathbf{w}_r - \widehat{\mathbf{w}}_r )$ over all other weights $\{\mathbf{w}_r\}_{r \in S^\c}$. Note that due to the delta functions, the subnetwork posterior is degenerate; this is why we use the Wasserstein distance, which remains well-defined for such degenerate distributions.
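For reference, the squared 2-Wasserstein distance between two Gaussians has a standard closed form (a textbook result, not specific to this work):

\begin{equation}
W_2^2\big(\mathcal{N}(\boldsymbol{\mu}_1, \boldsymbol{\Sigma}_1),\ \mathcal{N}(\boldsymbol{\mu}_2, \boldsymbol{\Sigma}_2)\big)
= \lVert \boldsymbol{\mu}_1 - \boldsymbol{\mu}_2 \rVert_2^2
+ \mathrm{Tr}\Big(\boldsymbol{\Sigma}_1 + \boldsymbol{\Sigma}_2 - 2\big(\boldsymbol{\Sigma}_2^{1/2} \boldsymbol{\Sigma}_1 \boldsymbol{\Sigma}_2^{1/2}\big)^{1/2}\Big),
\end{equation}

and a delta function can be viewed as the zero-covariance limit of a Gaussian, which is why the distance above stays well-defined for the degenerate subnetwork posterior.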
Unfortunately, this objective is still intractable, as it depends on all entries of the Hessian of the full network. To obtain a tractable objective, we assume that the full network posterior is factorised. By making this factorisation assumption, the Wasserstein objective now only depends on the diagonal entries of the Hessian, which are cheap to compute. I know what you’re thinking right now: “Didn’t they just tell us that the whole point of this subnetwork inference thing is to avoid making the assumption that the posterior is diagonal? And now they’re telling us that, actually, we still do have to make this assumption? This doesn’t make any sense!”
Well, in fact, it turns out that making the diagonal assumption just for the purpose of subnetwork selection, but then doing full-covariance Gaussian posterior inference over the subnetwork is much better than making the diagonal assumption for the purpose of inference itself (i.e. inference over the weights of the subnetwork and even over all weights), which we’ll see in the experiments later.
All in all, our proposed subnetwork selection procedure is as follows:

1. Estimate the marginal variance of each weight, using a diagonal (factorised) approximation to the full network posterior.
2. Select the weights with the largest marginal variances to form the subnetwork $\vw_S$, as these contribute most to the Wasserstein objective.
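Under the diagonal assumption, this selection boils down to a simple top-$k$ over the marginal variances. In the sketch below, `h_diag` is a hypothetical stand-in for the diagonal Hessian entries, and the subnetwork budget is an arbitrary choice:

```python
import numpy as np

rng = np.random.default_rng(2)

# h_diag stands in for the diagonal of the Hessian over all d weights
# (cheap to estimate); its reciprocal gives the marginal posterior
# variances under the factorisation assumption.
d = 10_000
h_diag = rng.gamma(shape=2.0, scale=1.0, size=d) + 1e-3
marginal_var = 1.0 / h_diag

# Keep the n_sub weights with the largest marginal variances
# (a hypothetical budget of 1% of the weights here).
n_sub = 100
S = np.argsort(marginal_var)[-n_sub:]
```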
Great, we now know that a subnetwork can be found by (approximately) minimising the Wasserstein distance between the subnetwork posterior and the full network posterior. But how do we make predictions with this weird approximate posterior that is partly probabilistic and partly deterministic? We simply use all the weights of the neural net to make predictions: we integrate out the weights in the subnetwork, and just leave all other weights fixed at their MAP estimates. For integrating out the subnetwork weights, one can either use Monte Carlo or a closed-form approximation — please refer to the full paper for more details (the reference is given at the end of this blog post). Subnetwork inference therefore combines the strong predictive accuracy of the MAP estimate with the calibrated uncertainties of a Bayesian posterior.
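A minimal Monte Carlo sketch of this prediction scheme follows; the tiny `net` function and all shapes are hypothetical stand-ins for a real network and its subnetwork Laplace posterior:

```python
import numpy as np

rng = np.random.default_rng(3)

def net(w, x):
    # hypothetical stand-in for a neural network forward pass
    return np.tanh(x @ w)

d = 8
w_map = rng.normal(size=d)          # MAP estimate over all d weights
S = np.array([0, 3, 5])             # subnetwork indices (hypothetical)
cov_S = 0.1 * np.eye(len(S))        # subnetwork covariance from the Laplace step
L = np.linalg.cholesky(cov_S)

x = rng.normal(size=(4, d))         # a small batch of test inputs
n_samples = 200
preds = []
for _ in range(n_samples):
    w = w_map.copy()                               # non-subnetwork weights stay at MAP
    w[S] = w_map[S] + L @ rng.normal(size=len(S))  # w_S ~ N(w_map[S], cov_S)
    preds.append(net(w, x))
preds = np.stack(preds)

mean_pred = preds.mean(axis=0)       # predictive mean
std_pred = preds.std(axis=0)         # predictive uncertainty
```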
Finally, we will now demonstrate the effectiveness of subnetwork inference in two experiments.
In the first experiment we train a small, 2-layer, fully-connected network with a total of 2600 weights on a 1D regression task, shown in Figure 7.
The number in brackets in the plot title denotes the number of weights over which we do inference; for example, for the MAP estimate (Figure 7, top left), inference was done over zero weights. As you can see, the 1D function we’re trying to fit consists of two separated clusters of data, and the goal here is to capture as much of the predictive uncertainty as possible, especially in-between those data clusters (Foong et al., 2019). As expected, the point estimate (Figure 7, top left) doesn’t capture any uncertainty, but instead makes confident predictions even in regions where there’s no data, which is bad.
On the other extreme, we can infer a full covariance Gaussian posterior over all the 2600 weights using a Laplace approximation (Figure 7, top middle), which is only tractable here due to the small model size. As we can see, a full-covariance Gaussian posterior is able to capture predictive uncertainty both at the boundaries and in-between the data clusters, so we will consider this to be the ideal, ground-truth posterior for this experiment.
Of course, in larger-scale settings, a full-covariance Gaussian would be intractable, so people often resort to diagonal approximations which assume full independence between the weights (Figure 7, top right). Unfortunately, as we can see, even though we do inference over all 2600 weights, due to the diagonal assumption we sacrifice a lot of the predictive uncertainty, especially in-between the two data clusters, where it’s only marginally better than the point estimate.
Now what about our proposed subnetwork inference method? First, let’s try doing full-covariance Gaussian inference over only 50% (that is, 1300) of the weights, found using the described Wasserstein minimisation approach (Figure 7, bottom left). As we can see, this approach captures predictive uncertainty much better than the diagonal posterior, and is even quite close to the full-covariance Gaussian over all weights. Well, but 50% is still quite a lot of weights, so let’s try to go even smaller, much smaller, to only 3% of the weights, which corresponds to just 78 weights here (Figure 7, bottom middle). Even then, we’re still much better off than the diagonal Gaussian. Let’s push this to the extreme, and estimate a full-covariance Gaussian over as little as 1% (that is, 26) of the weights (Figure 7, bottom right). Perhaps surprisingly, even with 1% of weights remaining, we do significantly better than the diagonal baseline, and are able to capture significant in-between uncertainty!
Overall, the take-away from this experiment is that doing expressive inference over a very small, but carefully chosen subnetwork, and capturing weight correlations just within that subnetwork can preserve more predictive uncertainty than a method that does inference over all the weights, but ignores weight correlations.
Ok, 1D regression is fun, but we’re of course interested in scaling this to more realistic settings. In this second experiment, we consider the task of image classification under distribution shift. This task is much more challenging than 1D regression, so the model that we use is significantly larger than before: we use a ResNet-18 model with over 11 million weights, and, to remain tractable, we do inference over as little as 42 thousand (which is only around 0.38%) of the weights, again found using Wasserstein minimisation.
We consider six baselines: the MAP estimate, a diagonal Laplace approximation over all 11M weights, Monte Carlo (MC) dropout over all weights (Gal and Ghahramani, 2015), Variational Online Gauss-Newton (short VOGN, Osawa et al., 2019), which estimates a factorised Gaussian over all weights, a Deep Ensemble (Lakshminarayanan et al., 2017) of 5 independently trained ResNet-18 models, and Stochastic Weight Averaging Gaussian (short SWAG, Maddox et al., 2019), which estimates a low-rank plus diagonal posterior over all weights. As another baseline, we also consider subnetwork inference with a randomly selected subnetwork (denoted Ours (Rand)), which will allow us to assess the impact of how the subnetwork is chosen.
We consider two benchmarks for evaluating robustness to distribution shift which were recently proposed by Ovadia et al. (2019) (Figure 8): firstly, we have rotated MNIST, where the model is trained on the standard MNIST training set, and then at test time evaluated on increasingly rotated MNIST digits (as for example shown for the digit 2 in Figure 8, top); and secondly, we consider corrupted CIFAR-10, where we again train on the standard CIFAR-10 training set, but then evaluate on corrupted CIFAR-10 images; the test set contains over a dozen different corruption types, with five levels of increasing corruption severity (in this example, the image of a dog in Figure 8, bottom, is getting more and more blurry from left to right).
Let’s start with rotated MNIST (Figure 9). On the x-axis, we have the degree of rotation, and on the y-axis, we plot two different metrics: on top, we plot the test errors achieved by the different methods (where lower values are better), and on the bottom, we plot the corresponding log-likelihood, as a measure of calibration (where higher values are better). Here, we see that MAP, diagonal Laplace, MC dropout, the deep ensemble, SWAG, and the random subnetwork baseline all perform roughly similarly in terms of calibration (Figure 9, bottom): their calibration becomes worse as we increase the degree of rotation; in contrast to that, subnetwork inference (shown in dark blue) remains much better calibrated, even at high degrees of rotation. The only competitive method here is VOGN, which slightly outperforms subnetwork inference in terms of calibration. Importantly, observe that this increase in robustness does not come at cost of accuracy (Figure 9, top): Wasserstein subnetwork inference (as well as VOGN) retain the same accuracy as the other methods.
Now let’s look at corrupted CIFAR10 (Figure 10). There, the story is somewhat similar: we plot the corruption severity on the x-axis versus the error (Figure 10, top) and log-likelihood (Figure 10, bottom) on the y-axis. Here, MAP, diagonal Laplace, MC dropout and the random subnetwork baseline are all poorly calibrated (Figure 10, bottom). VOGN, SWAG and deep ensembles are a bit better calibrated, but are still significantly outperformed by subnetwork inference (again in dark blue), even at high corruption severities. Importantly, the improved robustness of Wasserstein subnetwork inference again does not compromise accuracy (Figure 10, top). In contrast, the accuracy of VOGN suffers on this dataset.
Overall, the take-away from this experiment is that subnetwork inference is better calibrated and therefore more robust to distribution shift than state-of-the-art baselines for uncertainty estimation in deep neural nets.
To conclude, in this blog post, we described subnetwork inference, which is a Bayesian deep learning method that does expressive inference over a carefully chosen subnetwork within a neural network. We also showed some empirical results suggesting that this works better than doing crude inference over the full network. There remain clear limitations of this work that deserve more investigation in the future: The most important one is to develop better subnetwork selection strategies that avoid the potentially restrictive approximations we use (i.e. the diagonal approximation to the posterior covariance matrix).
Thanks a lot for reading this blog post! If you want to learn more about this work, please feel free to check out our full ICML 2021 paper:
Finally, we would like to thank Stratis Markou, Wessel Bruinsma and Richard Turner for many helpful comments on this blog post!
In this blog post, we will outline how we combine ideas from reinforcement learning and quantum chemistry to catalyse the search for new molecules. We will explain how we can push the boundaries of the type of molecules we can build by representing the atoms directly in Cartesian coordinates. Finally, we will demonstrate how we can exploit symmetries of the design process to efficiently train a reinforcement learning agent for molecular design.
To be able to design general molecular structures, it is critical to choose the right representation. Most approaches rely on graph representations of molecules, where atoms and bonds are represented by nodes and edges, respectively. However, this is a strongly simplified model designed for the description of single organic molecules. It is unsuitable for encoding metals and molecular clusters as it lacks information about the relative position of atoms in 3D space. Further, geometric constraints on the molecule cannot be easily encoded in the design process. A more general representation closer to the physical system is one in which a molecule is described by its atoms’ positions in Cartesian coordinates. We therefore directly work in this space.
In particular, we design molecules by sequentially drawing atoms from a given bag and placing them onto a 3D canvas. The canvas $\mathcal{C}$ contains all atoms $(e, x)$ with element $e \in \{\ce{H}, \ce{C}, \ce{N}, \ce{O}\, \dots \}$ and position $x \in \mathbb{R}^3$ placed so far, whereas the bag $\mathcal{B}$ comprises atoms still to be placed. We formulate this task as a sequential decision-making problem in a Markov decision process, where the agent is rewarded for building stable molecules. At the beginning of each episode, the agent receives an initial state $s_0 = (\mathcal{C}_{0}, \mathcal{B}_0)$, e.g. $\mathcal{C}_0 = \emptyset$ and $\mathcal{B}_0 = \ce{SOF_4}$ (see Figure 1). At each timestep $t$, the agent draws an atom from the bag and places it onto the canvas through action $a_t$, yielding reward $r_t$ and transitioning the environment into state $s_{t+1}$. This process is repeated until the bag is empty.
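The state and transition logic of this environment can be sketched in a few lines (class and method names are illustrative, and the real environment additionally computes a quantum-chemical reward, which we come to next):

```python
import numpy as np
from collections import Counter

class MolDesignEnv:
    """Canvas of placed (element, position) atoms plus a bag of atoms to place."""

    def __init__(self, bag):
        self.canvas = []            # atoms placed so far: (element, xyz) pairs
        self.bag = Counter(bag)     # atoms still to be placed

    def step(self, element, position):
        assert self.bag[element] > 0, "element not in bag"
        self.bag[element] -= 1
        self.canvas.append((element, np.asarray(position, dtype=float)))
        done = sum(self.bag.values()) == 0   # episode ends when the bag is empty
        return (list(self.canvas), Counter(self.bag)), done

# The SOF4 bag from Figure 1; the first action places the sulfur at the origin.
env = MolDesignEnv({"S": 1, "O": 1, "F": 4})
state, done = env.step("S", [0.0, 0.0, 0.0])
```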
An advantage of designing molecules in Cartesian space is that we can evaluate states in terms of quantum-mechanical properties.^{1} Here, the reward function encourages the agent to design stable molecules as measured in terms of their energy; however, linear combinations of other desirable properties (like drug-likeness or toxicity) would also be possible. We define the reward as the negative difference in energy between the resulting molecule described by $\mathcal{C}_{t+1}$ and the sum of energies of the current molecule $\mathcal{C}_t$ and a new atom of element $e_t$, \begin{equation} r(s_t, a_t) = \left[E(\mathcal{C}_t) + E(e_t)\right] - E(\mathcal{C}_{t+1}), \end{equation} where $E(e) := E(\{(e, [0,0,0]^T)\})$. We compute the energy using a fast semi-empirical quantum-chemical method. Importantly, the episodic return for building a molecule does not depend on the order in which atoms are placed.
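The reward itself is just a difference of energies. In the sketch below, `energy` is a toy stand-in for the semi-empirical quantum-chemical calculation (a Lennard-Jones-style pair sum, chosen only so the example runs without a chemistry package):

```python
import numpy as np

def energy(canvas):
    # Toy stand-in for the quantum-chemical energy: a Lennard-Jones-style
    # pair sum over all placed atoms (a lone atom therefore has energy 0).
    e = 0.0
    for i in range(len(canvas)):
        for j in range(i + 1, len(canvas)):
            r = np.linalg.norm(canvas[i][1] - canvas[j][1])
            e += 4.0 * ((1.0 / r) ** 12 - (1.0 / r) ** 6)
    return e

def reward(canvas_t, new_atom):
    # r(s_t, a_t) = [E(C_t) + E(e_t)] - E(C_{t+1})
    e_atom = energy([(new_atom[0], np.zeros(3))])
    canvas_next = canvas_t + [new_atom]
    return (energy(canvas_t) + e_atom) - energy(canvas_next)

canvas = [("O", np.array([0.0, 0.0, 0.0]))]
r = reward(canvas, ("H", np.array([1.1, 0.0, 0.0])))  # placing near the pair minimum
```

Placing the second atom near the toy potential's minimum yields a positive reward, i.e. the placement lowered the total energy.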
Learning to place atoms in Cartesian coordinates requires that the agent exploits the symmetries of the molecular design process. Therefore, we need a policy $\pi(a_t \vert s_t)$ that is covariant^{2} to translation and rotation. In other words, if the canvas is rotated or translated, the position $x$ of the atom to be placed should be rotated and translated as well.
To achieve this, we first encode the current state $s_t$ into an invariant representation $s^\text{inv} = \mathsf{SchNet}(s_t)$, where $\mathsf{SchNet}$ (Schütt et al., 2017; Schütt et al., 2018) is a neural network architecture that models interactions between atoms. Given $s^\text{inv}$, our agent selects 1) a focal atom $f$ among already placed atoms that acts as a reference point, 2) an available element $e$ from the bag for the atom to be placed, and 3) the position of the atom to be placed in internal coordinates (see Figure 2). These coordinates consist of the distance $d$ to the focal atom as well as angles $\alpha$ and $\psi$ with respect to the focal atom and its neighbours. Finally, we obtain a position $x$ that features the required covariance by mapping the internal coordinates back to Cartesian coordinates. We call the resulting agent $\mathsf{Internal}$.
Equipped with a policy, we can finally design some molecules! To demonstrate how the $\mathsf{Internal}$ agent works, we separately train it on the bags $\ce{CH_3N_O}, \ce{CH_4O}$ and $\ce{C_2H_2O_2}$ using PPO (Schulman et al., 2017). Figure 3 shows that the agent is able to learn interatomic distances as well as the rules of chemical bonding from scratch. On average, the agent reaches $90\%$ of the optimal return^{3} after only $12\,000$ steps. However, from the snapshots $\enclose{circle}{2}$ and $\enclose{circle}{3}$ in Figure 3 (b) we can see that the generated structures are not quite optimal yet. While the policy has mostly learned the atomic distances, it still has to figure out the angles between atoms. After training the policy for a bit longer, at point $\enclose{circle}{4}$ we finally generate valid, stable molecules. It works!
While these results look promising, the $\mathsf{Internal}$ agent actually struggles when faced with highly symmetric structures. As shown in Figure 4, that is because the choice of angles $\alpha$ and $\psi$ used in the internal coordinates can become ambiguous in such cases. A better approach would be to directly generate a rotation-covariant orientation $\tilde{x}$ of the atom to be placed without going through these internal coordinates.
Therefore, we replace the angles $\alpha$ and $\psi$ by directly sampling the orientation from a distribution on a sphere with radius $d$ centered at the focal atom $f$ (see Figure 5). We can define such a distribution using spherical harmonics, which are essentially basis functions defined on the sphere. In particular, we are able to model any (multi-modal) distribution on the sphere by generating the right coefficients $\hat{r}$ for these basis functions. To produce the coefficients, we modify $\mathsf{Cormorant}$ (Anderson et al., 2019), a neural network architecture for predicting properties of chemical systems that works entirely in Fourier space. Finally, we can sample a rotation-covariant orientation $\tilde{x}$ from the spherical distribution defined by $\hat{r}$ using rejection sampling. We call the resulting agent $\mathsf{Covariant}$.
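The rejection-sampling step can be sketched as follows. Here `p_unnorm` is a simple unimodal stand-in for the spherical density that the network actually defines via the coefficients $\hat{r}$, and `mu`, `kappa` and the radius `d` are hypothetical:

```python
import numpy as np

rng = np.random.default_rng(4)

# `p_unnorm` is a simple unimodal stand-in for the learned spherical density.
mu = np.array([0.0, 0.0, 1.0])
kappa = 4.0

def p_unnorm(u):
    return np.exp(kappa * (u @ mu))      # unnormalised density on the unit sphere

M = np.exp(kappa)                        # bound: p_unnorm(u) <= M for all unit u

def sample_direction():
    while True:
        u = rng.normal(size=3)
        u /= np.linalg.norm(u)           # uniform proposal on the sphere
        if rng.uniform() < p_unnorm(u) / M:
            return u                     # accepted: distributed as p_unnorm

d = 1.4                                  # distance chosen by the policy (hypothetical)
x_tilde = d * sample_direction()         # orientation scaled to the chosen radius
```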
To verify that $\mathsf{Covariant}$ works as expected, we compare it to the previous $\mathsf{Internal}$ agent on structures with high symmetry and coordination numbers. As shown in Figure 6 (a), $\mathsf{Covariant}$ is able to build valid molecules from the bags $\ce{SOF_4}$ and $\ce{IF_5}$ within $30\,000$ to $40\,000$ steps, whereas $\mathsf{Internal}$ fails to build low-energy configurations as it cannot distinguish highly symmetric intermediates. Further results in Figure 6 (b) for $\ce{SOF_6}$ and $\ce{SF_6}$ show that $\mathsf{Covariant}$ is capable of building such structures. While the constructed molecules are small in size, recall that they would be unattainable with graph-based methods as they lack important 3D information.
Let’s end here for now, even though we were really just getting started. To summarise, we have presented a novel reinforcement learning formulation for 3D molecular design guided by quantum chemistry. The key insight to get it to work was to exploit the symmetries of the design process, particularly using spherical harmonics.
One aspect we didn’t show today is how flexible this framework actually is. For example, we can use it to learn across multiple bags at the same time and generalise (to some extent) to unseen bags. Of course we can also scale up to larger molecules, though not quite as large as graph-based methods yet. Finally, we can even build molecular clusters, e.g. to model solvation processes. If that sounds interesting to you, make sure to check out the full papers:
In contrast, graph-based approaches have to resort to heuristic reward functions. ↩
More precisely, only the position $x$ needs to be covariant, whereas the element $e$ has to be invariant. ↩
We estimate the optimal return by using structural optimisation techniques to obtain the optimal structure and its energy. ↩
This is the first part of a two-part blog. This first part will involve quite a lot of detailed maths: we will derive a natural-gradient variational inference (NGVI) algorithm that can run on neural networks (NNs). We will follow the appendices in Khan et al. (2018). NGVI algorithms are in contrast to stochastic gradient algorithms such as Bayes-By-Backprop (Blundell et al., 2015), which also optimises the same Bayesian VI objective function, and also in contrast to Adam and SGD, which optimise for a non-Bayesian estimate of neural network weights.
In the second part of the blog, we will work our way to large datasets/architectures such as ImageNet/ResNets, discussing additional approximations required, as well as analysing their promising results. The second part will closely follow a paper I was involved in, Osawa et al. (2019).
I hope to leave the reader with an understanding of how NGVI algorithms for NNs are derived, and some intuition for their strengths and weaknesses. I will not discuss other Bayesian neural network algorithms, nor get involved in the recent debates over what it means to be ‘Bayesian’ in deep learning! $\newcommand{\vparam}{\boldsymbol{\theta}}$ $\newcommand{\veta}{\boldsymbol{\eta}}$ $\newcommand{\vphi}{\boldsymbol{\phi}}$ $\newcommand{\vmu}{\boldsymbol{\mu}}$ $\newcommand{\vSigma}{\boldsymbol{\Sigma}}$ $\newcommand{\vm}{\mathbf{m}}$ $\newcommand{\vF}{\mathbf{F}}$ $\newcommand{\vI}{\mathbf{I}}$ $\newcommand{\vg}{\mathbf{g}}$ $\newcommand{\vH}{\mathbf{H}}$ $\newcommand{\vs}{\mathbf{s}}$ $\newcommand{\myexpect}{\mathbb{E}}$ $\newcommand{\pipe}{\,|\,}$ $\newcommand{\data}{\mathcal{D}}$ $\newcommand{\loss}{\mathcal{L}}$ $\newcommand{\gauss}{\mathcal{N}}$
If you are reading this blog, hopefully you already know about Bayesian inference and its many promises when combined with deep learning: in short, we hope to obtain reliable confidence estimates, avoid overfitting on small datasets, and deal naturally with sequential learning. But exact Bayesian inference on large models such as neural networks is intractable.
Although there are many approximate Bayesian inference algorithms, we will only focus on variational inference. Blundell et al. (2015) introduced Bayes-By-Backprop for training NNs with VI. But this has been very difficult to scale to large NNs such as ResNets: the main problem is that optimisation is restrictively slow as it requires many passes through the data.
Separately, natural-gradient update steps were introduced as a principled way of incorporating the information geometry of the distribution being optimised (Amari, 1998). By incorporating the geometry of the distribution, we expect to take gradient steps in much better directions. This should speed up gradient optimisation significantly. For a more detailed explanation, please look at the motivation in papers such as Khan & Nielsen (2018) or Martens (2020); I found figures such as Figure 1(a) from Khan & Nielsen (2018) particularly useful.
It therefore makes sense to try and apply natural-gradient updates to VI for NNs, where speed of convergence has been an issue. In this blog post, we do this while looking closely at the mathematical details. We will follow the appendices in Khan et al. (2018) (there is a slightly different derivation in Zhang et al. (2018)). I also hope that, after reading this blog, you will be able to confidently approach recent papers that use NGVI, papers which often assume some knowledge of how NGVI algorithms are derived.
This section is a very brief overview of some fundamental concepts we will need. If you understand Equations \eqref{eq:exp_fam}, \eqref{eq:ELBO}, \eqref{eq:simple_NGD} & \eqref{eq:NGD}, then feel free to skip the text. If anything is unfamiliar, there will be links to some good references.
Exponential family distributions are commonly used in machine learning, with some key properties we can use. They include Gaussian distributions, which is the specific case we will consider later. Exponential family distributions are covered in most machine learning courses, and there are many good references, such as Murphy (2021) and Bishop (2006).
An exponential family distribution over parameters $\vparam$ with natural parameters $\veta$ has the form,
\begin{align} \label{eq:exp_fam} q(\vparam|\veta) = q_{\veta}(\vparam) = h(\vparam)\exp [ \langle\veta,\vphi(\vparam)\rangle - A(\veta) ], \end{align}
where $\vphi(\vparam)$ is the vector of sufficient statistics, $\langle \cdot,\cdot \rangle$ is an inner product, $A(\veta)$ is the log-partition function and $h(\vparam)$ is the base measure. We also assume a minimal exponential family, meaning that the sufficient statistics are linearly independent. This means that there is a one-to-one mapping between $\veta$ and the mean parameters $\vm = \myexpect_{q_\veta(\vparam)} [\vphi(\vparam)]$, and that $\vm = \nabla_\veta A(\veta)$.
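As a concrete example (a standard result, and the case we will use later), a multivariate Gaussian $\gauss(\vparam \pipe \vmu, \vSigma)$ is a minimal exponential family with

\begin{equation*}
\vphi(\vparam) = \big(\vparam,\ \vparam\vparam^\top\big), \qquad
\veta = \Big(\vSigma^{-1}\vmu,\ -\tfrac{1}{2}\vSigma^{-1}\Big), \qquad
\vm = \myexpect_{q_\veta(\vparam)}[\vphi(\vparam)] = \big(\vmu,\ \vSigma + \vmu\vmu^\top\big).
\end{equation*}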
In Bayesian inference, we want to learn the posterior distribution over parameters after observing some data $\data$. The posterior is given as,
\begin{equation} p(\vparam \cond \data) = \frac{ {\color{purple}p(\data\cond\vparam)} {\color{blue}p_0(\vparam)}}{p(\data)}, \nonumber \end{equation}
where ${\color{purple}p(\data\pipe \vparam)}$ is the data likelihood and ${\color{blue}p_0(\vparam)}$ is the prior over parameters. We will use colours to keep track of terms coming from the likelihood and prior. Note that in supervised learning, where the dataset $\data$ consists of inputs $\mathbf{X}$ and labels $\mathbf{y}$, we should write the likelihood as ${\color{purple}p(\mathbf{y}\pipe \mathbf{X}, \vparam)}$, but we slightly abuse notation by writing ${\color{purple}p(\data\pipe \vparam)}$.
If our likelihood and prior are set correctly, then exact Bayesian inference is optimal, but unfortunately there are problems in reality (this statement comes with many caveats! See e.g. this blog post for a more detailed discussion). To name two problems, (i) we are usually unsure if the likelihood or prior is correct, and (ii) exact Bayesian inference is often not possible, especially in NNs. In this blog post, we only focus on approaches to problem (ii): algorithms for approximate Bayesian inference. We do not consider problem (i).
Variational Bayesian inference approximates exact Bayesian inference by learning the parameters of a distribution $q(\vparam)$ that best approximates the true posterior distribution $p(\vparam \pipe \data)$. We do this by maximising the Evidence Lower Bound (ELBO), which is equivalent to minimising the KL divergence between the approximate distribution and the true posterior. By assuming that $q(\vparam)$ is an exponential family distribution $q_\veta(\vparam)$, we can write the ELBO as follows,
\begin{equation} \label{eq:ELBO} \loss_\mathrm{ELBO}(\veta) = \underbrace{\myexpect_{q_\veta(\vparam)} \left[\log {\color{purple}p(\data\pipe\vparam)}\right]}_\text{Likelihood term} + \underbrace{\myexpect_{q_\veta(\vparam)} \left[\log \frac{ {\color{blue}p_0(\vparam)}}{q_\veta(\vparam)} \right]}_\text{KL (to prior) term}, \end{equation}
which we optimise with respect to $\veta$. There are many good references on variational inference, such as Blei et al. (2018) or Zhang et al. (2018).
Let’s say that we have some function $\loss(\veta)$ that we want to optimise with respect to the parameters of an exponential family distribution $\veta$. Later in this blog post, this function will be the ELBO. Natural-gradient updates take gradient steps as follows until convergence,
\begin{equation} \label{eq:simple_NGD} \veta_{t+1} = \veta_t + \beta_t \vF(\veta_t)^{-1}\nabla_\veta \loss(\veta_t), \end{equation}
where $\nabla_\veta \loss(\veta_t) = \nabla_\veta \loss(\veta) \pipe_{\veta=\veta_t}$,
\begin{equation*} \vF(\veta_t) = \myexpect_{q_\veta(\vparam)} \left[ \nabla_\veta \log q_\veta(\vparam) \nabla_\veta \log q_\veta(\vparam)^\top \right], \end{equation*}
is the Fisher information matrix, and $\beta_t$ is a learning rate. As previously discussed, natural-gradient methods incorporate the information geometry of the distribution being optimised (through the Fisher information matrix), and therefore reduce the number of gradient steps required. Some good references include Amari (1998) and Martens (2014).
We can use a neat trick of exponential families to simplify the update step and side-step having to compute and invert the Fisher matrix directly (see e.g. Hoffman et al. (2013) or Khan & Lin (2017)). One way to show this is to note that
\begin{equation} \label{eq:mean-natural gradient} \nabla_\veta \loss(\veta_t) = [\nabla_\veta \vm_t] \nabla_\vm \loss_*(\vm_t) = [\nabla^2_{\veta\veta}A(\veta_t)] \nabla_\vm \loss_*(\vm_t) = \vF(\veta_t) \nabla_\vm \loss_*(\vm_t), \end{equation}
where $\loss_*(\vm)$ is the same function as $\loss(\veta)$ except written in terms of the mean parameters $\vm$. We have used the fact that $\vF(\veta) = \nabla^2_{\veta\veta}A(\veta)$, please see earlier references for this.
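The identity $\nabla_\veta \loss(\veta) = \vF(\veta)\nabla_\vm \loss_*(\vm)$ is easy to verify numerically for a 1D Gaussian using finite differences; everything below is a self-contained check with an arbitrary test objective, not part of any algorithm:

```python
import numpy as np

def m_of_eta(eta):
    # Mean parameters of a 1D Gaussian: m = (mu, sigma^2 + mu^2).
    e1, e2 = eta
    mu, var = -e1 / (2 * e2), -1 / (2 * e2)
    return np.array([mu, var + mu**2])

def L_star(m):
    # An arbitrary smooth test objective written in the mean parameters.
    return m[0] ** 2 + 3.0 * m[1]

def L_of_eta(eta):
    # The same objective written in the natural parameters.
    return L_star(m_of_eta(eta))

def num_grad(f, x, h=1e-6):
    g = np.zeros_like(x)
    for i in range(len(x)):
        dx = np.zeros_like(x)
        dx[i] = h
        g[i] = (f(x + dx) - f(x - dx)) / (2 * h)
    return g

eta = np.array([0.5, -0.25])   # corresponds to mu = 1, sigma^2 = 2

# F(eta) = Hessian of A(eta) = Jacobian of m(eta), since m = grad_eta A.
F = np.array([num_grad(lambda e: m_of_eta(e)[i], eta) for i in range(2)])

lhs = num_grad(L_of_eta, eta)                 # grad_eta L(eta)
rhs = F @ num_grad(L_star, m_of_eta(eta))     # F(eta) grad_m L*(m)
```

Both sides agree (to finite-difference precision), confirming that a natural-gradient step in $\veta$ is just a plain gradient step in the mean parameters.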
Plugging this in, we get our simplified natural-gradient update step,
\begin{equation} \label{eq:NGD} \veta_{t+1} = \veta_t + \beta_t \nabla_\vm \loss_*(\vm_t). \end{equation}
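We can sanity-check the identity behind this shortcut numerically. The sketch below uses a 1-D Gaussian (whose natural and mean parameters we will meet again shortly) and an arbitrary made-up objective; all specific numbers are purely illustrative:

```python
import numpy as np

# Natural parameters of a 1-D Gaussian: eta = (mu/var, -1/(2*var)).
# Mean parameters: m = (mu, mu^2 + var) = grad_eta A(eta).
def mean_from_natural(eta):
    mu = -eta[0] / (2.0 * eta[1])
    var = -1.0 / (2.0 * eta[1])
    return np.array([mu, mu**2 + var])

def loss_mean(m):                 # an arbitrary smooth toy objective
    return m[0]**2 + 3.0 * m[1]

def loss_natural(eta):            # the same objective, written in eta
    return loss_mean(mean_from_natural(eta))

def grad(f, x, h=1e-6):           # central finite differences
    g = np.zeros_like(x)
    for i in range(len(x)):
        e = np.zeros_like(x); e[i] = h
        g[i] = (f(x + e) - f(x - e)) / (2 * h)
    return g

eta = np.array([1.0, -0.5])       # mu = 1, var = 1
m = mean_from_natural(eta)

# Fisher matrix F(eta) = d m / d eta (the Hessian of the log-partition A).
F = np.zeros((2, 2))
for j in range(2):
    e = np.zeros(2); e[j] = 1e-6
    F[:, j] = (mean_from_natural(eta + e) - mean_from_natural(eta - e)) / 2e-6

# The identity: F(eta)^{-1} grad_eta L(eta) equals grad_m L_*(m).
nat_grad = np.linalg.solve(F, grad(loss_natural, eta))
grad_m = grad(loss_mean, m)
assert np.allclose(nat_grad, grad_m, atol=1e-4)
```

So taking a natural-gradient step in $\veta$ only requires the ordinary gradient with respect to the mean parameters, with no Fisher matrix inversion.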
We wish to combine variational inference with natural-gradient updates, so let’s get straight into it: we plug the ELBO (Equation \eqref{eq:ELBO}) into the natural-gradient update (Equation \eqref{eq:NGD}). Let the prior ${\color{blue}p_0(\vparam)}$ be an exponential family with natural parameters ${\color{blue}\veta_0}$. We first note that the KL term in the ELBO can be simplified as we are dealing with exponential families,
\begin{align} \nabla_\vm \,\text{KL term} &= \nabla_\vm \myexpect_{q_\veta(\vparam)} \left[ \boldsymbol{\phi}(\vparam)^\top ({\color{blue}\veta_0} - \veta) + A(\veta) + \text{const} \right] \nonumber\newline &= \nabla_\vm \left[ \vm^\top ({\color{blue}\veta_0} - \veta) \right] + \nabla_\vm A(\veta) \nonumber\newline &= {\color{blue}\veta_0} - \veta - \left[ \nabla_\vm\veta \right]^\top \vm + \nabla_\vm A(\veta) \nonumber\newline &= {\color{blue}\veta_0} - \veta - \vF(\veta)^{-1}\vm + \vF(\veta)^{-1}\vm \nonumber\newline &= {\color{blue}\veta_0} - \veta. \nonumber \end{align}
The third line follows using the product rule, and the fourth line uses $\nabla_\vm(\cdot) = \vF(\veta)^{-1} \nabla_\veta(\cdot)$ from Equation \eqref{eq:mean-natural gradient} and the symmetry of the Fisher information matrix. Plugging the ELBO (with this simplification) into Equation \eqref{eq:NGD},
\begin{align} \veta_{t+1} &= \veta_t + \beta_t \left( \nabla_\vm \myexpect_{q_{\veta_t}(\vparam)} \left[\log {\color{purple}p(\data\pipe\vparam)}\right] + ({\color{blue}\veta_0} - \veta_t) \right) \nonumber\newline \label{eq:BLR} \therefore \veta_{t+1} &= (1-\beta_t) \veta_t + \beta_t \Big({\color{blue}\veta_0} + \nabla_\vm \underbrace{\myexpect_{q_{\veta_t}(\vparam)} \left[\log {\color{purple}p(\data\pipe\vparam)}\right]}_{ {\color{purple}\Large\mathcal{F}_t}} \Big). \end{align}
This equation is presented and analysed in detail in Khan & Rue (2021), where it is called the ‘Bayesian learning rule’. I recommend reading the paper if you are interested in this and related topics: they show how this equation appears in many different scenarios (beyond just the Bayesian derivation presented above), and also consider extensions beyond what we consider in this blog post. This allows them to connect to a plethora of different learning algorithms ranging from Newton’s method to Kalman filters to Adam.
We now consider a Gaussian approximating family, $q_\veta(\vparam) = \gauss(\vparam; \vmu, \vSigma)$. We will substitute the Gaussian’s natural parameters into Equation \eqref{eq:BLR} to obtain updates for $\vmu$ and $\vSigma$. The minimal representation for a Gaussian family has two components to its natural parameters and mean parameters,
\begin{align} \veta^{(1)} &= \vSigma^{-1}\vmu, & \veta^{(2)} &= -\frac{1}{2}\vSigma^{-1}, \nonumber\newline \vm^{(1)} &= \vmu, & \vm^{(2)} &= \vmu\vmu^\top + \vSigma. \nonumber \end{align}
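In code, converting between the moment parameterisation $(\vmu, \vSigma)$ and the natural parameters is a pair of one-liners (a small sketch with arbitrary numbers):

```python
import numpy as np

def natural_from_moment(mu, Sigma):
    P = np.linalg.inv(Sigma)           # precision matrix
    return P @ mu, -0.5 * P            # eta^(1), eta^(2)

def moment_from_natural(eta1, eta2):
    Sigma = np.linalg.inv(-2.0 * eta2)
    return Sigma @ eta1, Sigma

# Arbitrary example values
mu = np.array([1.0, -2.0])
Sigma = np.array([[2.0, 0.5], [0.5, 1.0]])

eta1, eta2 = natural_from_moment(mu, Sigma)
mu2, Sigma2 = moment_from_natural(eta1, eta2)
# The round trip recovers the moment parameters.
assert np.allclose(mu, mu2) and np.allclose(Sigma, Sigma2)

# Mean parameters for the same Gaussian: m^(1) = mu, m^(2) = mu mu^T + Sigma.
m1, m2 = mu, np.outer(mu, mu) + Sigma
```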
Let the prior be a zero-mean Gaussian, ${\color{blue}p_0(\vparam) = \gauss(\vparam; \boldsymbol{0}, \delta^{-1}\vI)}$. We can therefore write the prior natural parameters as ${\color{blue}\veta_0^{(1)} = \boldsymbol{0}, \veta_0^{(2)} = -\frac{1}{2}\delta\vI}.$
We now simplify $\nabla_\vm {\color{purple}\mathcal{F}_t}$ to be in terms of $\vmu$ and $\vSigma$ instead of $\vm$. We can use the chain rule to do this (see e.g. Opper & Archambeau (2009) or Appendix B.1 in Khan & Lin, 2017),
\begin{align} \nabla_{\vm^{(1)}}{\color{purple}\mathcal{F}_t} &= \nabla_\vmu {\color{purple}\mathcal{F}_t} - 2[\nabla_\vSigma {\color{purple}\mathcal{F}_t}] \vmu, \nonumber\newline \nabla_{\vm^{(2)}}{\color{purple}\mathcal{F}_t} &= \nabla_\vSigma {\color{purple}\mathcal{F}_t}. \nonumber \end{align}
We would like to write out the natural-gradient updates (Equation \eqref{eq:BLR}) for the parameters of a Gaussian, with the resulting equations in terms of the prior natural parameters ${\color{blue}\veta_0}$ and the data ${\color{purple}\mathcal{F}_t}$. So let’s substitute the above derivations into Equation \eqref{eq:BLR}. We start with the second element, $\veta^{(2)}$, giving us an update for $\vSigma^{-1}$,
\begin{equation} \label{eq:Gaussian_Sigma} \vSigma_{t+1}^{-1} = (1-\beta_t)\vSigma_t^{-1} + \beta_t ({\color{blue}\delta\vI} - 2\nabla_\vSigma {\color{purple}\mathcal{F}_t}). \end{equation}
We also obtain an update for the mean $\vmu$ by looking at the first element $\veta^{(1)}$,
\begin{align} \vSigma_{t+1}^{-1} \vmu_{t+1} &= (1-\beta_t)\vSigma_{t}^{-1} \vmu_{t} + \beta_t ({\color{blue}\boldsymbol{0}} + \nabla_\vmu {\color{purple}\mathcal{F}_t} - 2 [\nabla_\vSigma {\color{purple}\mathcal{F}_t}] \vmu_t) \nonumber\newline &= \underbrace{\left[ (1-\beta_t)\vSigma_t^{-1} + \beta_t ({\color{blue}\delta\vI} - 2\nabla_\vSigma {\color{purple}\mathcal{F}_t}) \right]}_{=\vSigma_{t+1}^{-1}\text{, by Equation \eqref{eq:Gaussian_Sigma}}} \vmu_t + \beta_t (\nabla_\vmu {\color{purple}\mathcal{F}_t} - {\color{blue}\delta} \vmu_t) \nonumber\newline \label{eq:Gaussian_mu} \therefore \vmu_{t+1} &= \vmu_t + \beta_t \vSigma_{t+1} (\nabla_\vmu {\color{purple}\mathcal{F}_t} - {\color{blue}\delta} \vmu_t). \end{align}
The update for the precision $\vSigma^{-1}$ (Equation \eqref{eq:Gaussian_Sigma}) is a moving average: the precision slowly approaches and then tracks $({\color{blue}\delta\vI} - 2\nabla_\vSigma {\color{purple}\mathcal{F}_t})$. The update for the mean $\vmu$ (Equation \eqref{eq:Gaussian_mu}) is very similar to a (stochastic) gradient update. A key difference is the additional $\vSigma_{t+1}$ term, which (loosely speaking!) is the ‘natural-gradient’ part of the update: it automatically determines the learning rate for different elements of the mean $\vmu$. In the second part of the blog post, we will see how this relates to other algorithms such as Adam, which also try to automatically determine learning rates using data.
We are now very close to a complete NGVI algorithm. We just need to deal with the $\nabla_\vmu {\color{purple}\mathcal{F}_t}$ and $\nabla_\vSigma {\color{purple}\mathcal{F}_t}$ terms. Fortunately, we can express these in terms of the gradient and Hessian of the negative log-likelihood by invoking Bonnet’s and Price’s theorems (Opper & Archambeau, 2009; Rezende et al., 2014):
\begin{align} \label{eq:bonnet_gradient} \nabla_\vmu {\color{purple}\mathcal{F}_t} &= \nabla_\vmu \myexpect_{q_{\veta_t}(\vparam)} \left[\log {\color{purple}p(\data\pipe\vparam)}\right]& &= \myexpect_{q_{\veta_t}(\vparam)} \left[\nabla_\vparam \log {\color{purple}p(\data\pipe\vparam)}\right]& &= -\myexpect_{q_{\veta_t}(\vparam)} \left[N{\color{purple}\vg(\vparam)} \right], \newline \label{eq:bonnet_hessian} \nabla_\vSigma {\color{purple}\mathcal{F}_t} &= \nabla_\vSigma \myexpect_{q_{\veta_t}(\vparam)} \left[\log {\color{purple}p(\data\pipe\vparam)}\right]& &= \frac{1}{2}\myexpect_{q_{\veta_t}(\vparam)} \left[\nabla^2_{\vparam\vparam} \log {\color{purple}p(\data\pipe\vparam)}\right]& &= -\frac{1}{2}\myexpect_{q_{\veta_t}(\vparam)} \left[N{\color{purple}\vH(\vparam)} \right], \end{align}
where we have used the average per-example gradient ${\color{purple}\vg(\vparam)}$ and Hessian ${\color{purple}\vH(\vparam)}$ of the negative log-likelihood (the dataset has $N$ examples).
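Bonnet’s theorem is easy to verify numerically. For the made-up 1-D choice $f(\theta) = \theta^3$, both sides have closed forms ($\nabla_\mu \myexpect[f] = 3\mu^2 + 3\sigma^2$ and $\myexpect[f'(\theta)] = 3(\mu^2 + \sigma^2)$, which agree), and a Monte-Carlo estimate of the right-hand side should match:

```python
import numpy as np

rng = np.random.default_rng(2)
mu, var = 0.5, 2.0          # arbitrary Gaussian parameters

# Left-hand side: d/dmu E_{N(mu,var)}[theta^3] = 3 mu^2 + 3 var, in closed form.
lhs = 3 * mu**2 + 3 * var

# Right-hand side: Monte-Carlo estimate of E[f'(theta)] = E[3 theta^2].
theta = rng.normal(mu, np.sqrt(var), size=1_000_000)
rhs_mc = np.mean(3 * theta**2)

assert abs(rhs_mc - lhs) < 0.05   # agree up to Monte-Carlo noise
```

Price’s theorem (the gradient with respect to $\vSigma$) can be checked the same way using second derivatives.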
One final step. Until now we have been exact in our derivations (given a VI objective and Gaussian approximating family). But we now need to make our first approximation to estimate Equations \eqref{eq:bonnet_gradient} and \eqref{eq:bonnet_hessian}: we use a Monte-Carlo sample $\vparam_t \sim q_{\veta_t}(\vparam) = \gauss(\vparam; \vmu_t, \vSigma_t)$ to approximate the expectation terms. We expect any approximation error to reduce as we increase the number of Monte-Carlo samples.
This leads to an algorithm called Variational Online-Newton (VON) in Khan et al. (2018),
\begin{align} \label{eq:VON_Sigma} \hspace{1em}\vSigma_{t+1}^{-1} &= (1-\beta_t)\vSigma_t^{-1} + \beta_t (N {\color{purple}\vH(\vparam_t)} + {\color{blue}\delta\vI}) \newline \label{eq:VON_mu} \vmu_{t+1} &= \vmu_t - \beta_t \vSigma_{t+1} (N{\color{purple}\vg(\vparam_t)} + {\color{blue}\delta}\vmu_t). \end{align}
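As a sanity check, here is a minimal sketch of VON on a made-up linear-Gaussian toy model, where the gradient and Hessian of the negative log-likelihood are available in closed form and the exact posterior is Gaussian. The dataset, step size and iteration count are arbitrary choices:

```python
import numpy as np

rng = np.random.default_rng(0)

# Toy model: y_i = x_i^T theta + unit-Gaussian noise, prior N(0, delta^{-1} I),
# so the exact posterior is Gaussian and known in closed form.
N, d, delta = 50, 2, 1.0
X = rng.normal(size=(N, d))
theta_true = np.array([1.0, -0.5])
y = X @ theta_true + rng.normal(size=N)

def g(theta):                        # average per-example gradient of the NLL
    return X.T @ (X @ theta - y) / N

H = X.T @ X / N                      # average per-example Hessian (constant here)

# VON updates, one Monte-Carlo sample per step.
mu, Prec = np.zeros(d), np.eye(d)
beta = 0.1
for t in range(500):
    Sigma = np.linalg.inv(Prec)
    theta_s = rng.multivariate_normal(mu, Sigma)   # theta_t ~ q
    Prec = (1 - beta) * Prec + beta * (N * H + delta * np.eye(d))
    mu = mu - beta * np.linalg.inv(Prec) @ (N * g(theta_s) + delta * mu)

# Compare against the exact posterior.
Prec_exact = X.T @ X + delta * np.eye(d)
mu_exact = np.linalg.solve(Prec_exact, X.T @ y)
assert np.allclose(Prec, Prec_exact, rtol=1e-6)    # precision converges exactly
assert np.allclose(mu, mu_exact, atol=0.3)         # mean converges up to MC noise
```

Because the log-likelihood is quadratic here, the precision update converges exactly; the mean hovers around the exact posterior mean with noise set by the single-sample Monte-Carlo estimate.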
We can run this algorithm on models where we can calculate the gradient and Hessian (such as by using automatic differentiation). But calculating Hessians of (non-toy) neural networks is still difficult. We therefore have to approximate the Hessian ${\color{purple}\vH(\vparam_t)}$ in some way. This is done in the algorithm VOGN.
The Gauss-Newton matrix (Martens, 2014; Graves, 2011; Schraudolph, 2002) approximates the Hessian using only first-order information, ${\color{purple}\vH(\vparam_t)} = -\frac{1}{N}\nabla^2_{\vparam\vparam} \log {\color{purple}p(\data \pipe \vparam)} \approx \frac{1}{N} \sum_{i\in\data} {\color{purple}\vg_i(\vparam_t) \vg_i(\vparam_t)^\top} $. It has some nice properties, such as being positive semi-definite (which we require), making it a suitable choice. We expect it to become a better approximation to the Hessian as we train for longer. As we will see in the second part of this blog, it can also be calculated relatively quickly. Please see the references for more details on the benefits and disadvantages of using the Gauss-Newton matrix, such as how it is connected to the (empirical) Fisher Information Matrix.
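As a small illustration of why this matrix is convenient, the sketch below builds the approximation from (made-up) per-example gradients and checks that it is positive semi-definite by construction:

```python
import numpy as np

rng = np.random.default_rng(1)

# Stand-in per-example gradients of the negative log-likelihood,
# one row per example (random numbers, for illustration only).
N, d = 20, 3
G = rng.normal(size=(N, d))

# Gauss-Newton-style approximation: average of the per-example
# gradient outer products.
H_gn = G.T @ G / N

# Symmetric positive semi-definite by construction, as required
# for a valid precision update.
eigvals = np.linalg.eigvalsh(H_gn)
assert np.all(eigvals >= -1e-10)
```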
This Gauss-Newton approximation is the key approximation to go from VON to VOGN. But we also make some other approximations to allow for good scaling to large datasets/architectures: we restrict the covariance to be diagonal (mean-field), storing only a scale vector $\vs$; we estimate the per-example gradients on minibatches $\mathcal{M}_t$ of size $M$; and we rescale the prior precision to ${\color{blue}\tilde{\delta}} = {\color{blue}\delta}/N$, with a separate learning rate $\alpha_t$ for the mean.
These changes lead to our final VOGN algorithm, ready for running on neural networks,
\begin{align} \label{eq:VOGN_mu} \vmu_{t+1} &= \vmu_t - \alpha_t \frac{ {\color{purple}\hat{\vg}(\vparam_t)} + {\color{blue}\tilde{\delta}}\vmu_t}{\vs_{t+1} + {\color{blue}\tilde{\delta}}}, \newline \label{eq:VOGN_Sigma} \vs_{t+1} &= (1-\beta_t)\vs_t + \beta_t \frac{1}{M} \sum_{i\in\mathcal{M}_t}\left( {\color{purple}\vg_i(\vparam_t)}^2 \right), \end{align}
where ${\color{blue}\tilde{\delta}} = {\color{blue}\delta}/N$, and all operations are element-wise.
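A single VOGN step can be sketched as follows. This is a minimal illustration rather than the full algorithm: in practice the per-example gradients would be evaluated at a weight sample $\vparam_t \sim q$, and all numbers below are made up:

```python
import numpy as np

def vogn_step(mu, s, grads, alpha, beta, delta_tilde):
    """One VOGN step. `grads` holds per-example gradients of the negative
    log-likelihood for the current minibatch, one row per example."""
    g_hat = grads.mean(axis=0)                       # minibatch gradient
    # Precision-scale update first, since the mean update uses s_{t+1}.
    s_new = (1 - beta) * s + beta * (grads**2).mean(axis=0)
    mu_new = mu - alpha * (g_hat + delta_tilde * mu) / (s_new + delta_tilde)
    return mu_new, s_new

# Toy call with made-up numbers: 3 parameters, minibatch of 2 examples.
mu, s = np.zeros(3), np.ones(3)
grads = np.array([[0.2, -0.1, 0.4],
                  [0.0,  0.3, 0.2]])
mu, s = vogn_step(mu, s, grads, alpha=0.1, beta=0.1, delta_tilde=0.01)
assert mu.shape == (3,) and np.all(s > 0)
```

All operations are element-wise, so the per-step cost is comparable to Adam-style optimisers.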
So how does this perform in practice? We explore this in detail in the second part of the blog. For now, we borrow Figure 1(b) from Khan & Nielsen (2018), which shows how Natural-Gradient VI (VOGN) can converge much quicker than Gradient VI (implemented as Bayes-By-Backprop (Blundell et al., 2015)) on two relatively small datasets.
I hope you now understand NGVI for BNNs a little better. You have seen how the equations are derived, and hopefully have more of a feel for why and when they might work. There was a lot of detailed maths, but I have tried to provide some intuition and make all our approximations clear.
In this first part, we stopped at VOGN on small neural networks. In the second part, we will compare VOGN with stochastic gradient-based algorithms such as SGD and Adam to provide some further intuition. We will take some inspiration from Adam to scale VOGN to much larger datasets/architectures, such as ImageNet/ResNets. The next blog post will be a lot less mathematical!
If you would like to cite this blog post, you can use the following bibtex:
@misc{swaroop_2021, title={Natural-Gradient Variational Inference}, url={https://mlg-blog.com/2021/04/13/ngvi-bnns-part-1.html}, author={Swaroop, Siddharth}, year={2021}}
The theory of probabilities is at bottom nothing but common sense reduced to calculus; it enables us to appreciate with exactness that which accurate minds feel with a sort of instinct for which ofttimes they are unable to account.
— Pierre-Simon Laplace (1749–1827)
The Cambridge Machine Learning Group is renowned for having drunk the Bayesian Kool-Aid. We evangelise about probabilistic approaches in our teaching and research as a principled and unifying view of machine learning and statistics. In this context, I (Rich) have found it striking that many of the most influential contributors to this brand of machine learning and statistics are more circumspect. From Michael Jordan remarking that “this place is far too Bayesian”,^{1} to Geoff Hinton saying that, having listed the key properties of the problems he’s interested in, Bayesian approaches just don’t cut it. Even Zoubin Ghahramani confides that aspects of the Bayesian approach “keep him awake at night”.^{2}
So, as we kick-off this new blog, we thought we’d dig into these concerns and attempt to burst the Bayesian bubble. In order to do this, we’ve written two posts. In this first part, we present some of the standard sunny arguments for probabilistic inference and decision making. In the second part, we’ll shoot some of these arguments down and face the demons lurking in the night.
The goal of an inference problem is to estimate an unknown quantity $X$ from known quantities $D$. Examples include inferring the mass of the Higgs boson ($X$) from collider data ($D$); estimating the prevalence of Covid-19 infections ($X$) from PCR test data ($D$); or reconstructing files ($X$) from corrupted versions stored on a damaged hard disk ($D$).
The probabilistic approach to solving such problems proceeds in three stages:
Stage 1: probabilistic modelling. The first stage is called probabilistic modelling and involves designing a probabilistic recipe that describes how all the variables, known variables $D$ and unknown variables $X$, are assumed to be produced. The model specifies a joint distribution over all these variables $p(X, D)$ and samples from it should reflect typical settings of the variables that you might have expected to encounter before seeing any data.
Stage 2: probabilistic inference. The second stage is called probabilistic inference. In this stage, the sum and product rules of probability are used to manipulate the joint distribution over all variables into the conditional distribution of the unknown variables given the known variables:
\begin{equation} \label{eq:posterior} p( X \cond D ) = \frac{p( X, D)}{p( D )}. \end{equation}
The distribution on the left-hand side of this equation is known as the posterior distribution. It tells us how probable any setting of the unknown variables $X$ is given the known variables $D$. In this way, it tells us not only the most likely setting of the unknown variables ($\hat X_{\text{MAP}} = \argmax_{X} p(X \cond D)$), but also our uncertainty about $X$: it summarises our belief about $X$ after seeing $D$. Equation \eqref{eq:posterior} follows from the product rule of probability; another application of the product rule leads to Bayes’ rule.
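For intuition, here is the posterior computation in the simplest possible setting: a binary unknown, with made-up numbers for the prior and likelihood:

```python
# Discrete toy example: X in {0, 1}, with a made-up prior p(X)
# and a made-up likelihood p(D | X) for one observed dataset D.
p_x = {0: 0.5, 1: 0.5}                 # prior
p_d_given_x = {0: 0.2, 1: 0.8}         # likelihood of the observed D

# Evidence p(D) by the sum rule, posterior by the product rule.
p_d = sum(p_x[x] * p_d_given_x[x] for x in p_x)
posterior = {x: p_x[x] * p_d_given_x[x] / p_d for x in p_x}

assert abs(sum(posterior.values()) - 1.0) < 1e-12
assert abs(posterior[1] - 0.8) < 1e-12   # 0.5 * 0.8 / 0.5
```

The continuous case replaces the sum over $X$ with an integral; it is this normalising integral that makes exact inference hard in general.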
In the example of inferring the mass of the Higgs boson, the unknown $X$ is a parameter.^{3} Inferring parameters using the sum and product rules to form the posterior distribution is called being Bayesian.
Stage 3: Bayesian decision theory. In real-world problems, inferences are usually made to serve decision making. For example, inferences could inform the design of a new particle accelerator to pin down particle masses; could decide whether to implement another national lockdown; or could decide whether to prosecute someone for financial crimes based on data recovered from a hard drive.
The probabilistic approach supports decision making in a third stage which goes by the grand name of Bayesian decision theory. Here the user provides a loss function $L(X, \delta)$ which specifies how unhappy they would be if they take decision $\delta$ when the unknown variables take a value $X$. We can compute how unhappy they will be on average with any decision by averaging the loss function over all possible settings of the unobserved variables, weighted by how probable they are under the posterior distribution $p( X \cond D )$:
\begin{equation} (\text{average unhappiness})(\delta) = \E_{p(X \cond D)}[L(X, \delta)]. \end{equation}
This quantity is called the posterior expected loss^{4}. We now pick the decision that we expect to be least unhappy about by minimising our expected unhappiness with respect to our decision $\delta$, and we’re done!
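Here is the whole decision-theoretic recipe in miniature, with made-up posterior probabilities and a made-up loss table:

```python
# Posterior over the unknown X and a loss table L(X, decision);
# we pick the decision minimising the posterior expected loss.
posterior = {"sunny": 0.7, "rainy": 0.3}            # made-up beliefs
loss = {                                            # loss[(X, decision)]
    ("sunny", "umbrella"): 1.0, ("sunny", "no umbrella"): 0.0,
    ("rainy", "umbrella"): 0.0, ("rainy", "no umbrella"): 5.0,
}

def expected_loss(decision):
    return sum(p * loss[(x, decision)] for x, p in posterior.items())

best = min(["umbrella", "no umbrella"], key=expected_loss)
assert best == "umbrella"    # 0.7 * 1 = 0.7 beats 0.3 * 5 = 1.5
```

Note how the decision depends on the loss as well as the beliefs: with a milder penalty for getting wet, the same posterior would recommend leaving the umbrella at home.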
Summary. This framework proposes a cleanly separated sequential three-step procedure: first, articulate your assumptions about the data via a probabilistic model; second, compute the posterior distribution over unknown variables; and third, select the decision which minimises the average loss under the posterior.
Why does a Bayesian represent their beliefs with probabilities^{5}, reason according to the sum and product rule, and select actions which minimise the posterior expected loss? We’ll now review the most common theoretical arguments.
(1) de Finetti’s exchangeability theorem justifies the use of model parameters $\theta$, conditional distributions $p(D_n \cond \theta)$ over data given parameters (also called the likelihood of parameters), and critically prior distributions over parameters $p(\theta)$ when specifying probabilistic models. The theorem says that if you believe the order in which the data $D = (D_n)_{n=1}^N$ arrives is unimportant — $p(D)$ is invariant to the order of $(D_n)_{n=1}^N$, an idea called exchangeability — then there exists a random variable $\theta$ with associated prior distribution such that the data $D$ are i.i.d. given $\theta$ and your belief is recovered by marginalising over $\theta$:
\begin{equation} p(D) = \int p(\theta) \prod_{n=1}^N p(D_n \cond \theta) \,\mathrm{d}\theta. \end{equation}
The argument presupposes the use of probability distributions over data, but shows that this in combination with exchangeability entails the existence of parameters with associated prior and posterior distributions. This idea was important historically as there were schools of statistical thought that eschewed placing distributions over parameters, but were happy placing distributions over data.
(2) Cox’s theorem and coherence. Cox’s theorem (Cox, 1946; Jaynes, 2003) justifies the use of a probabilistic model and application of the sum and the product rules to perform inference. The theorem starts out by listing a number of desiderata that any reasonable system of quantitative rules for inference should satisfy. One very important such desideratum is consistency or coherence^{6}: if there are multiple ways of arriving at an inference, they should all give the same answer. For example, updating our beliefs about an unknown variable $X$ after observing data $D=(D_n)_{n=1}^N$ should give the same result as incrementally updating our beliefs about $X$ one data point $D_n$ at a time. Cox’s theorem is the conclusion that every system satisfying the desiderata must be probability theory. In particular, it identifies probability theory as the unique extension of propositional logic, where a proposition is either true or false, to varying degrees of plausibility.
(3) The Dutch book argument (Ramsey, 1926; de Finetti, 1931) is another argument for the optimality of modelling and inference, one which connects to decision making. The argument goes as follows: if you’re willing to take bets on propositions with certain odds (in a way, these odds represent your beliefs), then, unless these odds are consistent with probability theory, you’re willing to take a collection of bets that nets a sure loss (a Dutch book).
(4) Savage’s theorem (Savage, 1954) is used to argue that optimal decision making entails the three-step probabilistic approach. Like Cox’s theorem, it takes an axiomatic approach, but here the axioms relate to decisions rather than inferences. The axioms include the idea that a decision maker is characterised by the ability to rank all decisions in some order of preference. It then lists properties that any reasonable ranking should have.^{7} Savage’s theorem says that these properties entail that the decision maker’s ranking is consistent with them acting according to Bayesian decision theory (Karni, 2005).
The above arguments justify large parts of the probabilistic approach to inference and decision making. In contrast, the arguments we turn to next are only concerned with specific properties. One way to view them is as unit tests that the Bayesian framework passes.
Before we turn to them, you may have noticed that the formulation of the probabilistic approach and the justifications made so far do not include a notion of the true model or true parameters; rather, all that matters is your own personal beliefs about the world, how you update them as data arrive, and how decisions are made. However, it is perfectly reasonable to ask questions like “How does the posterior over parameters behave when the data were generated using some true underlying parameter value?” or “If I have the right model and apply probabilistic inference, will my predictions be good?” The next two results step out of the Bayesian framework to answer these questions:
(5) Doob’s consistency theorem (Doob, 1949) shows that Bayesian inference is consistent: very often, if the data were sampled from $p(D \cond X)$ for some true value for $X$, then, as the user collects more and more data, the posterior over the unknown variables $p(X \cond D)$ concentrates on this true value for $X$. This is a frequentist analysis of a Bayesian estimation procedure, and it shows that the two paradigms can live comfortably side by side: Bayesian methods provide estimation procedures; frequentist tools allow analysis of these procedures. This is one reason why Michael Jordan thought that any Bayesian-focussed research group worth its salt should be paying close attention to frequentist ideas.
(6) Optimality of Bayesian predictions. In many applications, such as those typically encountered in machine learning, predicting future data points is of central focus. What guarantees do we have on the quality of such estimates arising from the probabilistic approach? Well, if we use the Kullback–Leibler (KL) divergence to measure the distance between the true density $p(X \cond \theta)$ over the unknown $X$ and any data-dependent estimate of this density $q(X \cond D)$, then, when averaged over potential settings of the parameters and associated observed data $p(\theta)p(D\cond\theta)$, the estimate $q^*$ that minimises the average divergence is the Bayesian posterior predictive (Aitchison, 1975):
\begin{equation} q^*(X \cond D) = \argmin_{q} \E_{p(\theta,D)}[ \operatorname{KL}( p(X \cond \theta) \| q(X \cond D)) ] = p(X \cond D). \end{equation}
This argument tells us that the Bayesian predictions coming from the “right model” are KL-optimal on average. Interestingly, this result connects to recent ideas in meta-learning (Gordon et al., 2019).
(7) Wald’s theorem (Wald, 1949) can be used to justify minimising an expected loss as a way of making decisions. The theorem is concerned with admissible decision rules: rules that, for every other decision rule, achieve a strictly better loss for at least some realisation of the unknown $X$. This condition is a very low bar: we’d hope that any reasonable decision rule would have this property. However, surprisingly, Wald’s theorem says that the only rules which are admissible are essentially those derived from minimising the expected loss under some distribution (Wald, 1949; Lehmann & Casella, 1998).
It is striking that a number of arguments based on a diversity of desirable properties — including coherence, optimal betting strategies, specifying sensible preferences over actions, frequentist guarantees like consistency and optimal predictive accuracy — all suggest that the probabilistic approach to inference is a reasonable one. But in the next post we’ll ask whether, in the dead of night, everything is as rosy as it seems in daylight.
Actually he was referring to the Gatsby unit in 2004 — in many ways the mother of the Cambridge Machine Learning Group — and his comment was a fair one. ↩
This was in the Approximate Inference Workshop at NeurIPS in 2017. ↩
Parameters are distinguished from variables by asking what happens as we see more data: variables get more numerous, parameters do not. ↩
We previously incorrectly called the posterior expected loss the Bayes risk. Thanks to Corey Yanofsky for pointing out the mistake. ↩
That probabilities represent degrees of belief is only one interpretation of probability. For example, in Probability, Statistics, and Truth (Von Mises, 1928), von Mises argues that probability concerns limiting frequencies of repeating events. In this view and contrary to the Bayesian view, it is meaningless to talk about the probability of a one-off event: it is not possible to repeatedly sample that one-off event, which means that it doesn’t have a limiting frequency. ↩
Consistency also has another specific technical meaning in statistics, so we will use the term coherence in what follows. ↩
The axioms are reminiscent of those used in Arrow’s impossibility theorem concerning fair voting systems. ↩
The theory of subjective probability describes ideally consistent behaviour and ought not, therefore, be taken too literally.
— Leonard Jimmie Savage (1917–1971)
In the first post in this series, we laid out the standard arguments that we and many others have used to support the edifice which is Bayesian inference. In this post, we identify the weaknesses in these arguments that cause us to lose sleep at night.
We’ve split these weaknesses into three types: first, weaknesses in the standard mathematical justifications for the probabilistic approach; second, weaknesses arising in practice from the modelling stage; and, third, weaknesses arising from realising the inference stage due to computational constraints.
We have seen that probability theory and Bayesian decision theory are usually justified in one of several ways, but do these standard justifications really stand up to scrutiny? Let’s go through each of the seven arguments in turn.
(1) de Finetti’s exchangeability theorem presupposes a probability distribution over the data and shows this naturally leads to distributions over parameters. However, it does not justify the use of probability in the first place.
(2) Cox’s theorem does justify the probabilistic approach, but making the argument watertight turns out to be far more delicate than the textbooks, say of Jaynes (2003) or Bishop (2006), would have you believe. To make the theorem mathematically rigorous requires additional technical assumptions that muddy the clarity of the argument. As Paris (1994, p. 24) puts it: “[W]hen an attempt is made to fill in all the details some of the attractiveness of the original is lost.”^{1} Moreover, there remains disagreement about the desirability of several of Cox’s theorem’s other assumptions.^{2} Perhaps the most controversial assumption is that plausibilities are represented by real numbers. As a consequence, for every two possible propositions, one of the two propositions conclusively has higher (or equal) belief. (This is called universal comparability.) But what if we are truly ignorant about two matters, or have not yet formed an opinion? In that case, is it reasonable to require that the plausibilities assigned to the matters are necessarily comparable?^{3}
(3) The Dutch book argument tells us that any coherent actor should use probability to express their beliefs, but the argument leaves the door open as to how the actor should update their beliefs in light of new evidence (Hacking, 1976). This is because the standard Dutch book setup is static: it does not involve a step where beliefs are updated on the basis of new information. That Bayes’ rule should be used for this purpose requires additional assumptions. Dynamic alternatives of the Dutch book argument attempt to fix this flaw, but again the force of the argument is diminished and open to criticism (Skyrms, 1987).^{4}
(4) Savage’s theorem guarantees a nearly unique loss function (unique up to an affine transform) and a unique probability, which together form the expected loss. The troubling aspect of Savage’s theorem is that the constructed loss function depends only on the outcome of a decision, and Savage’s axioms imply that outcomes of decisions have a value which is independent of the state of the world (Karni, 2005). As Wakker & Zank (1998) remark, this disentanglement can be undesirable. They give the example of health insurance, where the value of money depends on sickness or health. If the loss function is allowed to depend on other aspects of the world, then the probability constructed by the theorem is no longer unique (Karni, 2005). For this reason, Karni (2005) argues that the probability constructed by the theorem is arbitrary and thus cannot realistically represent the decision maker’s beliefs.
(5) Doob’s consistency theorem, the (6) optimality of Bayesian predictions and (7) Wald’s theorem are all only unit tests: how reassured should we really be that Bayesian inference passes them? It is not clear that the guarantees on the optimality of Bayesian predictions are the guarantees you care about. For example, typically we’re faced with a single dataset and care only about performance on it alone, rather than the average performance across many potential datasets. Similarly, the admissibility of a decision rule in Wald’s theorem is desirable but not sufficient: indeed, there are admissible estimators that are not reasonable.^{5} Doob’s consistency theorem also suffers from subtle theoretical issues (Diaconis & Freedman, 1986), which we discuss below in the context of model mismatch.
Take away. It is clear then, that uniquely and precisely justifying the three stage Bayesian approach via a single argument is much more delicate than many would have you believe. However, these issues don’t trouble our sleep. That is partly due to the fact that it is reassuring that so many and so diverse a set of arguments suggest that it is a sensible approach.^{6} However, it is also because there are far bigger issues to worry about.
All practitioners of Bayesian inference will know well that beliefs are typically only roughly encoded into probabilistic models. There are at least three good reasons for this. First, it is often hard to really pin down precisely what you believe: What tail behaviour should a variable have? Are there latent variables at play? Et cetera. Second, even when you do have an ideal model in mind, it can be mathematically challenging to accurately translate that into a probability distribution. Third, many modelling choices are often based on convenience such as mathematical tractability.
Should we be worried about this? Surely roughly encoding our prior knowledge is sufficient?
Unfortunately, seemingly small or irrelevant inaccuracies in the model can greatly affect the posterior and therefore downstream decision making. For example, Diaconis & Freedman (1986) show that “in high-dimensional problems, arbitrary details of the prior really matter”. Similarly, Kass & Raftery (1995) conclude that, in the context of Bayesian model comparison, “[t]he chief limits of Bayes factors are their sensitivity to the assumptions in the parametric model and the choice of priors.” Indeed, for this reason, the textbook presentation of model comparison, such as that in David MacKay’s excellent textbook, should be seen as only a pedagogical depiction of Bayesian inference at work, rather than an approach that will bear practical fruit: discrete Bayesian model comparison does not work in practice.^{7} As a more recent example of the importance of priors in high-dimensional settings, experiments by Wenzel et al. (2020) suggest that the usual choice of Gaussian prior^{8} for Bayesian neural network models contributes to the cold posterior effect in which the Bayesian posterior is outperformed on prediction tasks by strongly tempered versions.
Can theory about the performance of the Bayesian approach in the face of model misspecification act as a comfort blanket? There are theorems, analogous to Doob’s consistency theorem in the well-specified case, that describe situations when the Bayesian posterior will still concentrate on the true parameter value even when the model is misspecified (Kleijn & van der Vaart, 2006; De Blasi & Walker, 2013; Ramamoorthi et al., 2015), but, as Grünwald & van Ommen (2017) point out^{9}, “[these theorems hold] under regularity conditions that are substantially stronger than those needed for consistency when the model is correct.”^{10}
Theoretical understanding of the consequences of model misspecification is not yet available to rescue us. The probabilistic approach to modelling therefore poses an unsettling dichotomy: on the one hand, you’re free to choose the prior, because it is your belief, your prior; but on the other hand, you should choose your prior absolutely right, because seemingly small or irrelevant changes can greatly affect the conclusions.
A way forward? This weakness of the Bayesian approach used to keep us awake at night. However, taking a slightly different perspective cures our insomnia. The conventional view dogmatically insists that Bayesians must encapsulate their prior assumptions, immutably and up front, in the modelling step before seeing data. When the data arrive, they perform inference and then decision making, and at that point they are done. The alternative perspective instead uses the three-stage process as a consistency-checking device: if I made these assumptions, then these would be the corresponding coherent statistical inferences. In this way, we are free to explore a range of different assumptions, assessing their consequences for both the model and the inferences we make. We are free to modify the model accordingly until we arrive at something that well approximates what we believe. This view has been called the hypothetico–deductive view of Bayesian inference (Gelman & Shalizi, 2011), and it has the advantage that this is the way many of us use these methods in practice. The disadvantage is that double-dipping is possible, which requires that you are careful about the checks and diagnostics that you perform on the model and the corresponding inferences.
The weakness that kept Zoubin Ghahramani (and many other Bayesians) awake at night is that the Bayesian posterior is computationally intractable in all but the simplest cases. Consequently, approximations are necessary, which means that all the theoretical guarantees and justifications that are true for the exact Bayesian posterior no longer hold. Worse still, we do not know in what ways approximation will affect different aspects of the solution.
For example, one common approximation technique is variational inference (Wainwright & Jordan, 2008), which side-steps intractable computations arising from application of the sum and product rules by projecting distributions onto simpler tractable alternatives. Variational inference is therefore guaranteed to be incoherent: it returns different solutions from exactly applying the sum and product rules. How do these approximate inferences differ from the true ones?
Variational inference is known to (1) underestimate posterior uncertainty if factorised approximations are used^{11} and (2) bias parameter learning so that overly simple models are returned that underfit the data (Turner & Sahani, 2011). An example of the first phenomenon is the observation that the mean-field variational approximation in neural networks is unable to model in-between uncertainty (Foong et al., 2019). Examples of the second phenomenon are over-pruning in mixture models (MacKay, 2001; Blei & Jordan, 2006), variational autoencoders^{12} (Burda et al., 2015; Chen et al., 2016; Zhao et al., 2017; Yeung et al., 2017), and Bayesian neural networks (Trippe & Turner, 2018).
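The first phenomenon has a small analytic illustration (a standard textbook sketch, along the lines of Bishop’s PRML, Section 10.1.2; the specific numbers are ours): for a Gaussian target, the optimal factorised approximation under the reverse KL matches the target’s means but takes each factor’s precision from the diagonal of the target’s precision matrix, so strong correlations translate into badly underestimated marginal variances.

```python
import numpy as np

# True posterior: a bivariate Gaussian with unit marginal variances and
# correlation rho.
rho = 0.9
Sigma = np.array([[1.0, rho], [rho, 1.0]])

# For a Gaussian target, the optimal mean-field (factorised) approximation
# minimising KL(q || p) is Gaussian with each factor's precision equal to the
# corresponding diagonal entry of the target's precision matrix.
Lambda = np.linalg.inv(Sigma)
q_var = 1.0 / np.diag(Lambda)   # equals 1 - rho**2 for this target

print("true marginal variances:", np.diag(Sigma))
print("mean-field variances:   ", q_var)
```

With $\rho = 0.9$, the true marginal variances are $1$ but each mean-field factor has variance $1 - \rho^2 = 0.19$: the stronger the correlation the approximation discards, the more overconfident its marginals become.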
Inevitably, these errors are amplified if repeated approximation steps are required. For example, in online or continual learning the goal is to incorporate new observations sequentially without forgetting old ones (Nguyen et al., 2017). In theory, repeated application of Bayes’ rule provides the optimal solution, but in practice approximations like variational inference are required at each step. This causes amplification of the approximation error, which eventually leads the model to “forget” old data. Consequently, Bayesian methods are approximate and have to be checked and benchmarked for catastrophic forgetting, just like their non-Bayesian counterparts.
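The mechanism can be sketched in a toy linear-Gaussian continual-learning problem (entirely illustrative; the model is ours, not from Nguyen et al.): after each observation we project the posterior onto a factorised Gaussian, and the projection errors compound, leaving the approximation wildly overconfident about a direction the data never constrain.

```python
import numpy as np

# Model: y_t = x1 + x2 + noise, noise ~ N(0, 1). The data constrain x1 + x2
# but say nothing about the direction d = x1 - x2.
h = np.array([1.0, 1.0])
rng = np.random.default_rng(0)
T = 50
ys = h @ np.array([1.0, -1.0]) + rng.normal(size=T)   # true x = (1, -1)

# Exact batch posterior under the prior x ~ N(0, I).
Lam_exact = np.eye(2) + T * np.outer(h, h)
Sig_exact = np.linalg.inv(Lam_exact)

# Sequential inference, projecting onto a factorised Gaussian after every step
# (for a Gaussian target, the reverse-KL mean-field solution keeps the mean
# and takes the diagonal of the precision matrix).
m, lam = np.zeros(2), np.ones(2)            # factorised posterior: mean, precisions
for y in ys:
    Lam = np.diag(lam) + np.outer(h, h)     # exact one-step posterior precision
    m = np.linalg.solve(Lam, lam * m + h * y)
    lam = np.diag(Lam).copy()               # project: keep only the diagonal

# Posterior variance of the unconstrained direction d = x1 - x2.
d = np.array([1.0, -1.0])
print("exact:     ", d @ Sig_exact @ d)          # stays at the prior value, 2.0
print("sequential:", d @ np.diag(1 / lam) @ d)   # collapses towards 0
```

The exact posterior correctly leaves the uncertainty in $x_1 - x_2$ at its prior value, but every projection step spuriously shrinks it a little more, so after fifty updates the sequential approximation claims near-certainty about a quantity the data say nothing about.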
We lack a general theory that justifies variational inference. Recent work has shown that it is frequentist consistent (Wang & Blei, 2017) under assumptions similar to those required for consistency of Bayesian inference under model misspecification. Although useful, this is a long way short of a general compelling justification for the approach.^{13}
Similar arguments can be made against other approaches to approximate inference such as Monte Carlo methods. For example, Markov chain Monte Carlo (MCMC) is guaranteed to eventually give you a close-enough answer; but it may take an unfeasibly long time to do so and diagnosing whether you have reached that point is very difficult (we often rely on heuristics to check whether the chain has been run for long enough). So, although excellent software for performing MCMC exists,^{14} ensuring that it is performing correctly is delicate.^{15}
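A minimal sketch of this failure mode (the target and chain settings are invented for illustration): a random-walk Metropolis chain on a well-separated bimodal target mixes happily within one mode, so a stable running mean and a healthy acceptance rate look reassuring even though the chain has explored only half of the posterior.

```python
import numpy as np

# Target: an equal mixture of N(-10, 1) and N(+10, 1), so the true mean is 0.
def log_target(x):
    return np.logaddexp(-0.5 * (x + 10) ** 2, -0.5 * (x - 10) ** 2)

def metropolis(x0, n_steps, step_size, rng):
    # Plain random-walk Metropolis with a Gaussian proposal.
    chain = np.empty(n_steps)
    x = x0
    for t in range(n_steps):
        prop = x + step_size * rng.normal()
        if np.log(rng.uniform()) < log_target(prop) - log_target(x):
            x = prop
        chain[t] = x
    return chain

rng = np.random.default_rng(0)
chain = metropolis(x0=10.0, n_steps=5000, step_size=1.0, rng=rng)

# The chain mixes well *within* the mode it started in, so naive diagnostics
# look fine, yet the estimated mean is nowhere near the true mean of 0.
print("chain mean:", chain.mean())
```

Multi-chain diagnostics such as $\hat{R}$ from dispersed starting points would catch this particular example, but in high dimensions there is no guarantee that any affordable set of starts covers all of the relevant modes.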
Let us take a step back and consider approximate inference in the context of Bayesian decision theory. In its ideal form, Bayesian decision theory is a cleanly separated three-step procedure consisting of model specification, inference and loss minimisation. However, in practice the use of approximate inference entangles the three steps: models which are simpler result in more accurate, and therefore more coherent, inference; and downstream decision making dictates what aspects of the true posterior your approximate posterior should capture and therefore feeds into the inference stage rather than being decoupled from it (Lacoste-Julien et al., 2011).^{16}
A way forward? To our minds, the jury is still out here about the best way forward. Hopefully a more general theory — potentially based on decision making under uncertainty and computational constraints that extends Cox’s, Ramsey’s and de Finetti’s, or Savage’s ideas — will emerge. In the meantime, continuing to provide more limited theoretical characterisation of the properties of existing inference approaches is vital. Practitioners should also test the hell out of their inference schemes to gain confidence in them. Testing on special cases where ground-truth posteriors or underlying variables are known is helpful. The acid test is whether your inference scheme works on the real-world data you care about, so test cases also need to replicate aspects of this situation. Here the ability of probabilistic models to sample fake data, potentially after training on real data, is very useful. This fits nicely with the hypothetico–deductive loop: specify your model, perform approximate inference, check your inferences, now check your model, review your modelling choices, and start the process again.
Of course the amazing and diverse practical successes of Bayesian inference — from cracking Enigma and finding Air France Flight 447, to the TrueSkill match-making system — give us great confidence in the utility of the probabilistic approach.
We started the first of these two posts by remarking that it was striking that Michael Jordan, Geoff Hinton, and Zoubin Ghahramani — individuals who had made huge contributions to the probabilistic brand of machine learning — maintained reservations about that very approach. We hope that this post has brought colour to these concerns.
Michael Jordan thought that Bayesians and frequentists should reconcile and operate arm-in-arm. This makes sense since Bayesians develop inference procedures and frequentist methods can be used to evaluate them.
Geoff Hinton’s view was that Bayesian approaches don’t cut it for the problems he is interested in, like object recognition and machine translation. This is understandable since we have seen that Bayesian methods are useful when (1) you’re able to encode knowledge carefully into a detailed probabilistic model; (2) you are uncertain and want to propagate uncertainty accurately; and (3) you can perform accurate approximate inference. Is this really true in situations like object recognition or machine translation? These problems involve learning complex high-dimensional representations from stacks of largely noise-free images, words, or speech, and, to quote Zoubin Ghahramani, “if our goal is to learn a representation (…) then Bayesian inference gets in the way”.^{17}
Zoubin Ghahramani admitted that approximate inference kept him awake at night. We have seen that the arguments that we make to justify the Bayesian approach — Cox’s and Savage’s theorems, Dutch books, et cetera — are practically meaningless because our inference is approximate. The elegant separation of modelling, inference, and decision making steps is, in fact, a tangled web.
So, when we teach the ideas of probabilistic modelling and inference and write boilerplate in our papers, we’d ask that you pause for breath before trotting out the standard justifications. Instead, consider evangelising a little less about how principled the approach is, mention the importance of exploring different modelling assumptions and their consequences, and stress that the quality of approximate inference should be evaluated in a systematic and principled way.
Also quoted by Terenin & Draper (2015). Halpern (1999) constructs a counterexample to Cox’s theorem if the additional technical assumption, Paris’ (1994) density assumption, is omitted. ↩
Van Horn (2003) concludes that “[they] cannot make a compelling case for all of [Cox’s theorem’s] requirements, however, and there remains disagreement as to the desirability of several of them.” ↩
Interestingly, there are two-dimensional theories of probability; see Section 4.1.2 by Bakker (2016) for an overview. ↩
Skyrms (1987, p. 4) states a dynamic Dutch book argument due to David Lewis (reported by Paul Teller; 1973, 1976). Skyrms also says (p. 19) that “(…) not every learning situation is of the kind to which conditionalization applies. The situation may not be of the kind that can be correctly described or even usefully approximated as the attainment of certainty by the agent in some proposition in his degree of belief space. The rule of belief change by probability kinematics on a partition was designed to apply to a much broader class of learning situations than the rule of conditionalization.” In the paper, Skyrms describes the Observation Game, a learning situation where conditionalisation does not apply, where a generalisation of conditionalisation called probability kinematics (Jeffrey, 1965) — essentially, agreement with Bayes’ rule on only a given partition of the probability space — is necessary and, under certain conditions, sufficient to be bulletproof — a strong coherence condition that excludes a Dutch book. ↩
For example, see Example 5.7.2 by Lehmann & Casella (1998; also Makani, 1997). ↩
Savage’s (1962) comment on personal probability that was reproduced at the start of this post captures this sentiment well. ↩
Andrew Gelman and David MacKay discuss this issue here and here. ↩
We previously wrote “the usual choice of a Gaussian prior”, which, unfortunately, is slightly ambiguous. We mean to refer to the usual choices of $\mathbf{w} \sim \mathcal{N}(0, 1)$ or $\mathbf{w} \sim \mathcal{N}(0, 1/\sqrt{n})$. ↩
See also this great answer by Peter Grünwald on StackExchange. ↩
As a solution, Grünwald & van Ommen (2017) propose to replace the posterior by a generalised posterior, which depends on a learning rate. This, however, requires an appropriate choice of the learning rate, obscures the meaning of the likelihood, and, more importantly, invalidates the fundamental justifications which hold for Bayes’ rule. ↩
Often the approximations involve some form of factorisation assumption, but alternatives exist that have different properties. For example, inducing point approximations for Gaussian processes (Titsias, 2009) tend to overestimate uncertainty. ↩
As an attempt to improve the variational approximation, it has been proposed to recalibrate the variational objective by reweighting the KL term (Alemi et al., 2017; Higgins et al., 2017), which is closely related to the cold posterior effect. But by attempting to fix variational inference in this way, we are blurring the lines between Bayesian modelling and end-to-end optimisation, producing cleverly regularised estimators rather than models which reason according to fundamental principles. ↩
A pragmatic solution identifies standard practice, which uses point estimates for the unknowns such as $\hat X_{\text{MAP}}$, as effectively summarising the posterior $p(X \mid D)$ by a Dirac delta function, a probability distribution concentrated on $\hat X_{\text{MAP}}$. David MacKay (2003, Section 33.6) says that, “[f]rom this perspective, any approximating distribution $Q(x;\theta)$, no matter how crummy it is, has to be an improvement on the spike produced by the standard method!” However, “has” is doing a lot of work in this sentence. ↩
For example, Stan, PyMC3, or Turing.jl, but there is more out there. ↩
Michael Betancourt has a great post about responsible use of MCMC in practice. ↩
Indeed, the idea of decision-making-aware inference has been successful in the meta-learning setting (Gordon et al., 2019). It seems likely that the entanglement of inference and decision making (as a consequence of the intractability of the posterior) could justify the recent trend towards end-to-end systems, which directly optimise for the metric of interest, thereby circumventing these issues. ↩
The development of “Bayesian-inspired” approaches, which blend end-to-end deep learning with ideas from probabilistic modelling and inference is arguably a reaction to these concerns. The goal of these developments is to provide excellent representation learning and strong predictive performance whilst still handling uncertainty, but without claiming strict adherence to the Bayesian dogma. ↩