An introduction to Flow Matching

Date: 2024-01-20

Introduction

Flow matching (FM) is a recent generative modelling paradigm which has rapidly been gaining popularity in the deep probabilistic ML community. Flow matching combines aspects from Continuous Normalising Flows (CNFs) and Diffusion Models (DMs), alleviating key issues both methods have. In this blogpost we’ll cover the main ideas and unique properties of FM models starting from the basics.

Generative Modelling

Let’s assume we have data samples $x_1, x_2, \ldots, x_n$ from a distribution of interest $q_1(x)$, whose density is unknown. We’re interested in using these samples to learn a probabilistic model approximating $q_1$. In particular, we want efficient generation of new samples (approximately) distributed according to $q_1$. This task is referred to as generative modelling.

The advancement in generative modelling methods over the past decade has been nothing short of revolutionary. In 2012, Restricted Boltzmann Machines, then the leading generative model, were just about able to generate MNIST digits. Today, state-of-the-art methods are capable of generating high-quality images, audio and language, as well as modelling complex biological and physical systems. Unsurprisingly, these methods are now venturing into video generation.

Outline

Flow Matching (FM) models are in nature most closely related to (Continuous) Normalising Flows (CNFs). Therefore, we start this blogpost by briefly recapping the core concepts behind CNFs. We then continue by discussing the difficulties of CNFs and how FM models address them.

Normalising Flows

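Let $\phi: \mathbb{R}^d \rightarrow \mathbb{R}^d$ be a continuously differentiable function which transforms elements of $\mathbb{R}^d$, with a continuously differentiable inverse $\phi^{-1}: \mathbb{R}^d \to \mathbb{R}^d$. Let $q_0(x)$ be a density on $\mathbb{R}^d$ and let $p_1(\cdot)$ be the density induced by the following sampling procedure

\[\begin{equation*} \begin{split} x &\sim q_0 \\ y &= \phi(x), \end{split} \end{equation*}\]

which corresponds to transforming the samples of $q_0$ by the mapping $\phi$. Using the change-of-variable rule we can compute the density of $p_1$ as

\[\begin{align} p_1(y) &= q_0\big(\phi^{-1}(y)\big) \left| \det\left[\frac{\partial \phi^{-1}}{\partial y}(y)\right] \right| \label{eq:changevar} \\ &= \frac{q_0(x)}{\left| \det\left[\frac{\partial \phi}{\partial x}(x)\right] \right|} \quad \text{with} \quad x = \phi^{-1}(y), \label{eq:changevar-alt} \end{align}\]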

where the last equality can be seen from the fact that $\phi \circ \phi^{-1} = \Id$ and a simple application of the chain rule. The quantity $\frac{\partial \phi^{-1}}{\partial y}$ is the Jacobian of the inverse map. It is a matrix of size $d\times d$ with entries $J_{ij} = \frac{\partial \phi^{-1}_i}{\partial y_j}$. Depending on the task at hand, evaluation of likelihood or sampling, the formulation in $\eqref{eq:changevar}$ or $\eqref{eq:changevar-alt}$ is preferred (Friedman, 1987; Chen & Gopinath, 2000).

Example: Transformation of 1D Gaussian variables by linear map

Suppose $\phi$ is a linear function of the form \[\phi(x) = ax+b\] with scalar coefficients $a,b\in\mathbb{R}$, $a \neq 0$, and $p$ is Gaussian with mean $\mu$ and variance $\sigma^2$, i.e. \[p = \mathcal{N}(\mu, \sigma^2).\] We know from linearity of Gaussians that the induced $q$ will also be a Gaussian distribution, but with mean $a\mu+b$ and variance $a^2 \sigma^2$, i.e. \[q = \mathcal{N}(a \mu + b, a^2 \sigma^2).\] More interesting, though, is to verify that we obtain the same solution by applying the change-of-variable formula. The inverse map is given by \[\phi^{-1}(y) = \frac{y-b}{a}\] and its derivative w.r.t. $y$ is thus $1/a$, so the absolute value of the Jacobian determinant is $1/|a|$. We thus obtain \[\begin{align*} q(y) &= p\bigg(\frac{y-b}{a}\bigg) \frac{1}{|a|} \\ &= \mathcal{N}\bigg(\frac{y-b}{a}; \mu, \sigma^2\bigg) \frac{1}{|a|}\\ &= \frac{1}{\sqrt{2\pi\sigma^2}}\exp \bigg(-\frac{(y/a -b/a-\mu)^2}{2\sigma^2} \bigg)\frac{1}{|a|}\\ &= \frac{1}{\sqrt{2\pi(a\sigma)^2}}\exp \bigg(-\frac{1}{a^2}\frac{(y-(a\mu+b))^2}{2\sigma^2} \bigg) \\ &= \mathcal{N}\big(y; a\mu+b,a^2\sigma^2\big). \end{align*}\] We have thus verified that the change-of-variables formula can be used to compute the density of a Gaussian variable transformed by a linear mapping.

Often, to simplify notation, we will use the ‘push-forward’ operator $[\phi]_{\#}$ to denote the change in density of applying an invertible map $\phi$ to an input density. That is \[q(y) = ([\phi]_{\#} p)(y) = p\big(\phi^{-1}(y)\big) \left| \det\left[\frac{\partial \phi^{-1}}{\partial y}(y)\right] \right|.\] If we take $p = \mathcal{N}(0, 1)$ and make the choice of $a = 1$ and $b = \mu$, then we get $q = \mathcal{N}(\mu, 1)$, as can be seen in the figure below.
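As a quick numerical sanity check of the 1D example, here is a minimal sketch in plain Python. The function names and parameter values are illustrative assumptions, and the absolute value of the derivative is used so that negative $a$ is also covered.

```python
import math


def normal_pdf(x, mean, var):
    """Density of N(mean, var) evaluated at x."""
    return math.exp(-((x - mean) ** 2) / (2.0 * var)) / math.sqrt(2.0 * math.pi * var)


def pushforward_pdf(y, a, b, mu, sigma2):
    """Density of y = a * x + b with x ~ N(mu, sigma2), via change of variables."""
    x = (y - b) / a  # inverse map phi^{-1}(y)
    return normal_pdf(x, mu, sigma2) / abs(a)  # times |d phi^{-1} / dy| = 1 / |a|


# Illustrative parameter values (assumptions, not taken from the text).
a, b, mu, sigma2 = 2.0, -1.0, 0.5, 1.5
for y in (-3.0, 0.0, 0.7, 4.0):
    via_formula = pushforward_pdf(y, a, b, mu, sigma2)
    closed_form = normal_pdf(y, a * mu + b, a * a * sigma2)
    print(y, via_formula, closed_form)
```

The two printed columns should agree up to floating-point error for any $a \neq 0$.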

Transforming a base distribution $q_0$ into another $p_1$ via a transformation $\phi$ is interesting, yet its direct application in generative modelling is limited. In generative modelling, the aim is to approximate a distribution using only the available samples. Therefore, this task requires the transformation $\phi$ to map samples from a “simple” distribution, such as $\mathcal{N}(0,I)$, to approximately the data distribution. However, a straightforward linear transformation, as in the previous example, is inadequate due to the highly non-Gaussian nature of the data distribution. This brings us to a neural network as a flexible transformation $\phi_\theta$. The key task then becomes optimising the neural net’s parameters $\theta$.

Learning flow parameters by maximum likelihood

Let’s denote the induced parametric density by the flow $\phi_\theta$ as $p_1 \triangleq [\phi_\theta]_{\#}p_0$.

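A natural optimisation objective for learning the parameters $\theta \in \Theta$ is to consider maximising the probability of the data under the model:

\[\operatorname*{argmax}_{\theta \in \Theta} ~ \mathbb{E}_{x \sim q_1} \left[ \log p_1(x) \right].\]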

Parameterising $\phi_\theta$ as a deep neural network leads to several constraints:

How do we enforce invertibility?

How do we compute its inverse?

How do we compute the Jacobian efficiently?

Designing flows $\phi$ therefore requires trading off expressivity (of the flow and thus of the probabilistic model) against the above-mentioned considerations so that the flow can be trained efficiently.
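To make the maximum-likelihood objective concrete, here is a minimal sketch for the simplest possible flow, a single affine map $y = a x + b$ with base $\mathcal{N}(0, 1)$. This toy flow, the synthetic data and the names `log_likelihood`, `a_hat` and `b_hat` are assumptions for illustration; for this special case the maximiser is known in closed form (sample mean and standard deviation), so no gradient-based training is needed.

```python
import math
import random


def log_likelihood(data, a, b):
    """Average log-density of data under y = a * x + b with x ~ N(0, 1)."""
    total = 0.0
    for y in data:
        x = (y - b) / a  # phi^{-1}(y)
        log_base = -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)
        total += log_base - math.log(abs(a))  # + log |d phi^{-1} / dy|
    return total / len(data)


# Synthetic "data" from N(3, 2^2) (an assumption for the demo).
rng = random.Random(0)
data = [rng.gauss(3.0, 2.0) for _ in range(5000)]

# Closed-form maximiser for this one-layer affine flow.
b_hat = sum(data) / len(data)
a_hat = math.sqrt(sum((y - b_hat) ** 2 for y in data) / len(data))

print("fitted:", a_hat, b_hat, log_likelihood(data, a_hat, b_hat))
print("perturbed:", log_likelihood(data, 1.1 * a_hat, b_hat))
```

For a neural network $\phi_\theta$ no such closed form exists, which is where the invertibility and Jacobian questions above start to matter.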

Residual flow

In particular, computing the determinant of the Jacobian is in general very expensive (as it would require $d$ automatic differentiation passes through the flow), so we impose structure on $\phi$.

Full-rank residual (Behrmann et al., 2019; Chen et al., 2010)

Expressive flows relying on a residual connection have been proposed as an interesting middle-ground between expressivity and efficient determinant estimation. They take the form:

\[\begin{equation} \label{eq:full_rank_res} \phi_k(x) = x + \delta ~u_k(x), \end{equation}\]

where an unbiased estimate of the log-likelihood can be obtained. As opposed to auto-regressive flows (Huang et al., 2018; Larochelle and Murray, 2011; Papamakarios et al., 2017) and low-rank residual normalising flows (Van Den Berg et al., 2018), the update in \eqref{eq:full_rank_res} has a full-rank Jacobian, typically leading to more expressive transformations.

We can also compose such flows to get a new flow:

\[\phi = \phi_K \circ \ldots \circ \phi_2 \circ \phi_1.\]

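This can be a useful way to construct more expressive flows! The model’s log-likelihood is then given by summing each flow’s contribution

\[\log q(y) = \log p(x_1) + \sum_{k=1}^{K} \log \left| \det\left[\frac{\partial \phi_k^{-1}}{\partial x_{k+1}}(x_{k+1})\right] \right|\]

with $x_k = \phi_k^{-1} \circ \ldots \circ \phi^{-1}_{K} (y)$ and $x_{K+1} = y$.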
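The sum over per-layer log-determinants can be sketched with a stack of 1D affine layers, for which the composed density is known exactly. The layer parameters in `LAYERS` and the helper names are hypothetical; the base density is assumed to be $\mathcal{N}(0, 1)$.

```python
import math


def normal_logpdf(x, mean, var):
    """Log-density of N(mean, var) evaluated at x."""
    return -0.5 * math.log(2.0 * math.pi * var) - (x - mean) ** 2 / (2.0 * var)


# Each layer is an affine map phi_k(x) = a_k * x + b_k (hypothetical parameters).
LAYERS = [(2.0, 1.0), (0.5, -3.0), (-1.5, 0.25)]


def log_density(y, layers):
    """log q(y) for y = phi_K(... phi_1(x)) with x ~ N(0, 1)."""
    x = y
    log_det_sum = 0.0
    for a, b in reversed(layers):  # invert phi_K first, phi_1 last
        x = (x - b) / a
        log_det_sum += -math.log(abs(a))  # log |d phi_k^{-1} / dx_{k+1}|
    return normal_logpdf(x, 0.0, 1.0) + log_det_sum


def composed_affine(layers):
    """Collapse the stack into a single affine map y = A * x + B."""
    A, B = 1.0, 0.0
    for a, b in layers:
        A, B = a * A, a * B + b
    return A, B


A, B = composed_affine(LAYERS)
for y in (-2.0, 0.0, 4.0):
    print(y, log_density(y, LAYERS), normal_logpdf(y, B, A * A))
```

Since a composition of affine maps is affine, the layer-by-layer result can be compared against the log-density of $\mathcal{N}(B, A^2)$.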

Continuous time limit

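As mentioned previously, residual flows are transformations of the form $\phi(x) = x + \delta \ u(x)$ for some $\delta > 0$ and Lipschitz residual connection $u$. We can re-arrange this to get

\[\frac{\phi(x) - x}{\delta} = u(x)\]

which looks awfully similar to $u$ being a derivative. In fact, letting $\delta = 1/K$ and taking the limit $K \rightarrow \infty$ under certain conditions, a composition of residual flows $\phi_K \circ \cdots \circ \phi_2 \circ \phi_1$ is given by an ordinary differential equation (ODE):

\[\frac{\mathrm{d} x_t}{\mathrm{d} t} = \lim_{\delta \rightarrow 0} \frac{x_{t+\delta} - x_t}{\delta} = u_t(x_t)\]

where the flow of the ODE $\phi_t: [0,1]\times\mathbb{R}^d\rightarrow\mathbb{R}^d$ is defined such that

\[\frac{\mathrm{d} \phi_t}{\mathrm{d} t}(x_0) = u_t\big(\phi_t(x_0)\big) \quad \text{with} \quad \phi_0(x_0) = x_0.\]

That is, $\phi_t$ maps initial condition $x_0$ to the ODE solution at time $t$:

\[x_t \triangleq \phi_t(x_0) = x_0 + \int_0^t u_s(x_s) \, \mathrm{d}s.\]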
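The limit can be illustrated numerically: composing $K$ residual maps with $\delta = 1/K$ is exactly an explicit Euler discretisation of the ODE. The sketch below uses a hypothetical vector field $u_t(x) = -x$, chosen only because its exact flow $\phi_t(x_0) = x_0 e^{-t}$ is known.

```python
import math


def u(t, x):
    """Hypothetical vector field; the exact flow is phi_t(x0) = x0 * exp(-t)."""
    return -x


def compose_residual_flows(x0, K):
    """Apply K residual maps x -> x + delta * u(t_k, x) with delta = 1 / K."""
    delta = 1.0 / K
    x = x0
    for k in range(K):
        x = x + delta * u(k * delta, x)
    return x


x0 = 2.0
exact = x0 * math.exp(-1.0)  # phi_1(x0)
errors = [abs(compose_residual_flows(x0, K) - exact) for K in (10, 100, 1000)]
print(errors)
```

The error shrinks as $K$ grows, consistent with the composition converging to the ODE solution at $t = 1$.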

Continuous change-in-variables

Of course, this only defines the map $\phi_t(x)$; for this to be a useful normalising flow, we still need to compute the log-abs-determinant of the Jacobian!

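As it turns out, the density induced by $\phi_t$ (or equivalently $u_t$) can be computed via the following equation

\[\frac{\partial}{\partial t} p_t(x_t) = - \big(\nabla \cdot (u_t p_t)\big)(x_t).\]

This statement on the time-evolution of $p_t$ is generally known as the Transport Equation. We refer to $p_t$ as the probability path induced by $u_t$.

Computing the total derivative (as $x_t$ also depends on $t$) in log-space yields

\[\frac{\mathrm{d}}{\mathrm{d} t} \log p_t(x_t) = - (\nabla \cdot u_t)(x_t)\]

resulting in the log density

\[\log p_t(x_t) = \log p_0(x_0) - \int_0^t (\nabla \cdot u_s)(x_s) \, \mathrm{d}s.\]

Parameterising a vector field neural network $u_\theta: \mathbb{R}_+ \times \mathbb{R}^d \rightarrow \mathbb{R}^d$ therefore induces a parametric log-density

\[\log p_\theta(x) \triangleq \log p_1(x) = \log p_0(x_0) - \int_0^1 (\nabla \cdot u_\theta)(t, x_t) \, \mathrm{d}t \quad \text{with} \quad x_1 = x.\]

In practice, to compute $\log p_t$ one can either solve both the time evolution of $x_t$ and its log density $\log p_t$ jointly

\[\frac{\mathrm{d}}{\mathrm{d} t} \begin{pmatrix} x_t \\ \log p_t(x_t) \end{pmatrix} = \begin{pmatrix} u_\theta(t, x_t) \\ - (\nabla \cdot u_\theta)(t, x_t) \end{pmatrix}\]

or solve only for $x_t$ and then use quadrature methods to estimate $\log p_t(x_t)$.

Feeding this (joint) vector field to an adaptive step-size ODE solver allows us to control both the error in the sample $x_t$ and the error in the log-density $\log p_t(x_t)$.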
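As a rough sketch of the joint approach, the snippet below integrates a sample and its log-density together in 1D. It makes several simplifying assumptions: a fixed-step Euler scheme instead of an adaptive solver, a hypothetical vector field $u_t(x) = -x$, a finite-difference divergence, and $p_0 = \mathcal{N}(0, 1)$. For this field $p_t = \mathcal{N}(0, e^{-2t})$, so the result can be checked in closed form.

```python
import math


def u(t, x):
    """Hypothetical 1D vector field with divergence -1 everywhere."""
    return -x


def divergence(t, x, h=1e-5):
    """In 1D the divergence is du/dx; estimated here by central differences."""
    return (u(t, x + h) - u(t, x - h)) / (2.0 * h)


def integrate(x0, logp0, n_steps=1000):
    """Euler integration of the joint state (x_t, log p_t(x_t)) from t=0 to t=1."""
    dt = 1.0 / n_steps
    x, logp = x0, logp0
    for k in range(n_steps):
        t = k * dt
        dx = u(t, x)
        dlogp = -divergence(t, x)
        x, logp = x + dt * dx, logp + dt * dlogp
    return x, logp


def normal_logpdf(x, mean, var):
    return -0.5 * math.log(2.0 * math.pi * var) - (x - mean) ** 2 / (2.0 * var)


x0 = 0.7
x1, logp1 = integrate(x0, normal_logpdf(x0, 0.0, 1.0))
print(x1, logp1, normal_logpdf(x0 * math.exp(-1.0), 0.0, math.exp(-2.0)))
```

The last two printed values should be close: the integrated log-density and the exact log-density of $\mathcal{N}(0, e^{-2})$ at the exact end point.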

One may legitimately wonder why we should bother with such time-continuous flows versus discrete residual flows. There are a few benefits:

CNFs can be seen as an automatic way of choosing the number of residual flows $K$ to use, which would otherwise be a hyperparameter we would have to tune. In the time-continuous setting, we can choose an error threshold $\epsilon$ and the adaptive solver would give us the discretisation step size $\delta$, effectively yielding $K = 1/\delta$ steps.

Using an explicit first-order solver, each step is of the form $x \leftarrow x + \delta \ u_\theta(t_k, x)$, akin to a residual flow, where the residual connection parameters $\theta$ are shared for each discretisation step, since $u_\theta$ is amortised over $t$, instead of having a different $\theta_k$ for each layer.

In residual flows, during training we need to ensure that $u_\theta$ is $1 / \delta$-Lipschitz; otherwise the resulting flow will not be invertible and thus not a valid normalising flow. With CNFs, we still require the vector field $u_\theta(t, x)$ to be Lipschitz in $x$, but we don’t have to worry about exactly what this Lipschitz constant is, which is much easier to satisfy and enforce in the neural architecture.

Now that you know why CNFs are cool, let’s have a look at what such a flow would be for a simple example.

Example: Gaussian to a Gaussian (1D)

Let’s come back to our earlier example of mapping a 1D Gaussian to another one with a different mean. In contrast to previously, where we derived a ‘one-shot’ (i.e. discrete) flow bridging between the two Gaussians, we now aim to derive a time-continuous flow $\phi_t$ which corresponds to integrating a vector field $u_t$ over time. We have the following two distributions \[\begin{equation*} p_0 = \mathcal{N}(0, 1) \quad \text{and} \quad p_1 = \mathcal{N}(\mu, 1). \end{equation*}\] It’s not difficult to see that we can continuously bridge between these with a simple linear transformation \[\begin{equation*} \phi(t, x_0) = x_0 + \mu t \end{equation*}\] which is visualized in the figure below. By linearity, we know that every marginal $p_t$ is a Gaussian, and so \[\begin{equation*} \mathbb{E}_{p_0}[\phi_t(x_0)] = \mu t \end{equation*}\] which, in particular, implies that $\mathbb{E}_{p_0}[\phi_1(x_0)] = \mu = \mathbb{E}_{p_1}[x_1]$. Similarly, we have \[\begin{equation*} \mathrm{Var}_{p_0}[\phi_t(x_0)] = 1 \quad \implies \quad \mathrm{Var}_{p_0}[\phi_1(x_0)] = 1 = \mathrm{Var}_{p_1}[x_1] \end{equation*}\] Hence we have a probability path $p_t = \mathcal{N}(\mu t, 1)$ bridging $p_0$ and $p_1$. Now let’s determine what the vector field $u_t(x)$ would be in this case. As mentioned earlier, $u(t, x)$ should satisfy the following \[\begin{equation*} \dv{\phi_t}{t}(x_0) = u_t \big( \phi_t(x_0) \big). \end{equation*}\] Since we have already specified $\phi$, we can plug it in on the left-hand side to get \[\begin{equation*} \dv{\phi_t}{t}(x_0) = \dv{t} \big( x_0 + \mu t \big) = \mu \end{equation*}\] which gives us \[\begin{equation*} \mu = u_t \big( x_0 + \mu t \big). \end{equation*}\] The above needs to hold for all $t \in [0, 1]$, and so it’s not too difficult to see that one such solution is the constant vector field \[\begin{equation*} u_t(x) = \mu. \end{equation*}\] We could of course have gone the other way, i.e. define the $u_t$ such that $p_0 \overset{u_t}{\longleftrightarrow} p_1$ and derive the corresponding $\phi_t$ by solving the ODE.

Training CNFs

Similarly to any flows, CNFs can be trained by maximum log-likelihood

\[\mathcal{L}(\theta) = \mathbb{E}_{x \sim q_1} \left[ \log p_1(x) \right],\]

where the expectation is taken over the data distribution and $p_1$ is the parametric distribution. This involves integrating the time-evolution of samples $x_t$ and log-likelihood $\log p_t$, both terms being a function of the parametric vector field $u_{\theta}(t, x)$. This requires

⚠️ Expensive numerical ODE simulations at training time!

⚠️ Divergence estimators that scale well to high dimensions.

CNFs are very expressive as they parametrise a large class of flows, and therefore of probability distributions. Yet training can be extremely slow due to the ODE integration at each iteration. One may wonder whether a ‘simulation-free’, i.e. not requiring any integration, training procedure exists for training these CNFs.

Flow matching

And that is exactly where Flow Matching (FM) comes in!

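Flow matching is a simulation-free way to train CNF models where we directly formulate a regression objective w.r.t. the parametric vector field $u_\theta$ of the form

\[\mathcal{L}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1]} \, \mathbb{E}_{x \sim p_t} \left[ \| u_\theta(t, x) - u(t, x) \|^2 \right].\]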

In the equation above, $u(t, x)$ would be a vector field inducing a probability path (or bridge) $p_t$ interpolating the reference $p_0$ to $p_1$, i.e.

\[\log p_1(x) = \log p_0(x_0) - \int_0^1 (\nabla \cdot u_t)(x_t) \, \mathrm{d}t.\]

In words: we’re just performing regression on $u_t(x)$ for all $t \in [0, 1]$.

Of course, this requires knowledge of a valid $u(t, x)$, and if we already have access to $u_t$, there’s no point in learning an approximation $u_{\theta}(t, x)$ in the first place! But as we will see in the next section, we can leverage this formulation to construct a useful target for $u_{\theta}(t, x)$ without having to explicitly compute $u(t, x)$.

This is where Conditional Flow Matching (CFM) comes to the rescue.

Non-uniqueness of vector field We say a valid $u_t$ because there is no unique vector field $u_t$; there are indeed many valid choices for $u_t$ inducing maps $p_0 \overset{\phi}{\longleftrightarrow} p_1$ as illustrated in the figure below. As we will see in what follows, in practice we have to pick a particular target $u_t$, which has practical implications.

Conditional Flows

First, let’s remind ourselves that the transport equation relates a vector field $u_t$ to (the time evolution of) a probability path $p_t$

\[\frac{\partial}{\partial t} p_t(x) = - \big(\nabla \cdot (u_t p_t)\big)(x),\]

thus constructing $p_t$ or $u_t$ is equivalent. One key idea (Lipman et al., 2023; Albergo & Vanden-Eijnden, 2022) is to express the probability path as a marginal over a joint involving a latent variable $z$: $p_t(x_t) = \int p(z) ~p_{t\mid z}(x_t\mid z) \textrm{d}z$. The term $p_{t\mid z}(x_t\mid z)$ is a conditional probability path, satisfying some boundary conditions at $t=0$ and $t=1$ so that $p_t$ is a valid path interpolating between $q_0$ and $q_1$. In addition, as opposed to the marginal $p_t$, the conditional $p_{t\mid z}$ could be available in closed form.

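In particular, as we have access to data samples $x_1 \sim q_1$, it sounds pretty reasonable to condition on $z=x_1$, leading to the following marginal probability path

\[p_t(x_t) = \int q_1(x_1) ~p_{t\mid 1}(x_t\mid x_1) \textrm{d}x_1.\]

In this setting, the conditional probability path $p_{t\mid 1}$ needs to satisfy the boundary conditions

\[p_{0 \mid 1}(x \mid x_1) = p_0(x) \quad \text{and} \quad p_{1 \mid 1}(x \mid x_1) = \mathcal{N}(x; x_1, \sigmamin^2 I) \xrightarrow[\sigmamin \rightarrow 0]{} \delta_{x_1}(x)\]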

with $\sigmamin > 0$ small, and for whatever reference $p_0$ we choose, typically something “simple” like $p_0(x) = \mathcal{N}(x; 0, I)$, as illustrated in the figure below.

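The conditional probability path also satisfies the transport equation with the conditional vector field $u_t(x \mid x_1)$:

\[\begin{equation} \label{eq:continuity-cond} \frac{\partial}{\partial t} p_t(x \mid x_1) = - \nabla \cdot \big( u_t(x \mid x_1) \, p_t(x \mid x_1) \big). \end{equation}\]

Lipman et al. (2023) introduced the notion of Conditional Flow Matching (CFM) by noticing that this conditional vector field $u_t(x \mid x_1)$ can express the marginal vector field $u_t(x)$ of interest via the conditional probability path $p_{t\mid 1}(x_t\mid x_1)$ as

\[\begin{equation} \label{eq:cf-from-cond-vf} u_t(x) = \mathbb{E}_{x_1 \sim p_{1 \mid t}} \left[ u_t(x \mid x_1) \right] = \int u_t(x \mid x_1) \frac{p_t(x \mid x_1) q_1(x_1)}{p_t(x)} \textrm{d}x_1. \end{equation}\]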

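To see why this $u_t$ is the same vector field as the one defined earlier, i.e. the one generating the (marginal) probability path $p_t$, we need to show that the expression above for the marginal vector field $u_t(x)$ satisfies the transport equation

\[\frac{\partial}{\partial t} p_t(x) = - \nabla \cdot \big( u_t(x) p_t(x) \big).\]

Writing out the left-hand side, we have

\[\begin{equation*} \begin{split} \frac{\partial}{\partial t} p_t(x) &= \frac{\partial}{\partial t} \int p_t(x \mid x_1) q_1(x_1) \textrm{d}x_1 \\ &= \int \hlone{\frac{\partial}{\partial t} p_t(x \mid x_1)} q_1(x_1) \textrm{d}x_1 \\ &= - \int \hlone{\nabla \cdot \big( u_t(x \mid x_1) p_t(x \mid x_1) \big)} q_1(x_1) \textrm{d}x_1 \\ &= - \nabla \cdot \bigg( \hltwo{\int u_t(x \mid x_1) p_t(x \mid x_1) q_1(x_1) \textrm{d}x_1} \bigg) \\ &= - \nabla \cdot \big( \hltwo{u_t(x) p_t(x)} \big) \end{split} \end{equation*}\]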

where in the $\hlone{\text{first highlighted step}}$ we used \eqref{eq:continuity-cond} and in the $\hltwo{\text{last highlighted step}}$ we used the expression of $u_t(x)$ in \eqref{eq:cf-from-cond-vf}.

The relation between $\phi_t(x_0)$, $\phi_t(x_0 \mid x_1)$ and their induced densities is illustrated in Figure 9 below. And since $\phi_t(x_0)$ and $\phi_t(x_0 \mid x_1)$ are solutions corresponding to the vector fields $u_t(x)$ and $u_t(x \mid x_1)$ with $x(0) = x_0$, Figure 9 is equivalent to Figure 10, but note the difference in the expectation taken to go from $u_t(x_0 \mid x_1) \longrightarrow u_t(x_0)$ compared to $\phi_t(x_0 \mid x_1) \longrightarrow \phi_t(x_0)$.

Gaussian to Gaussian (2D) using a conditional flow

Let’s try to gain some intuition behind \eqref{eq:cf-from-cond-vf} and the relation between $u_t(x)$ and $u_t(x \mid x_1)$. We do so by looking at the following scenario \[\begin{equation} \tag{G-to-G} \label{eq:g2g} \begin{split} p_0 = \mathcal{N}([-\mu, 0], I) \quad & \text{and} \quad p_1 = \mathcal{N}([+\mu, 0], I) \\ \text{with} \quad \phi_t(x_0 \mid x_1) &= (1 - t) x_0 + t x_1 \end{split} \end{equation}\] with $\mu = 10$ unless otherwise specified. We’re effectively transforming a Gaussian to another Gaussian using a simple time-linear map, as illustrated in the following figure. In the end, we’re really just interested in learning the marginal paths $\phi_t(x_0)$ for initial points $x_0$ that are probable under $p_0$, which we can then use to generate samples $x_1 = \phi_1(x_0)$. In this simple example, we can obtain closed-form expressions for $\phi_t(x_0)$ corresponding to the conditional paths $\phi_t(x_0 \mid x_1)$ of \eqref{eq:g2g}, as visualised below. With that in mind, let’s pick a random initial point $x_0$ from $p_0$, and then compare a MC estimator for $u_t(x_0)$ at different values of $t$ along the path $\phi_t(x_0)$, i.e. we’ll be looking at \[\begin{equation*} \begin{split} u_t \big( \phi_t(x_0) \big) &= \E_{p_{1 \mid t}}\left[u_t \big( \phi_t(x_0) \mid x_1 \big)\right] \\ &\approx \frac{1}{n} \sum_{i = 1}^n u_t \big( \phi_t(x_0) \mid x_1^{(i)} \big) \ \text{with } x_1^{(i)} \sim p_{1 \mid t}(x_1 \mid \phi_t(x_0)). \end{split} \end{equation*}\] In practice we don’t have access to the posterior \(p_{1 \mid t}(x_1 \mid x_t)\), but in this specific setting we do have closed-form expressions for everything (Albergo & Vanden-Eijnden, 2022), and so we can visualise the marginal vector field \(u_t\big( \phi_t(x_0)\big)\) and the conditional vector fields \(u_t \big( \phi_t(x_0) \mid x_1^{(i)} \big)\) for all our “data” samples \(x_1^{(i)}\) and see how they compare. This is shown in the figure below. From the above figures, we can immediately see how for small $t$, i.e. near 0, the posterior $p_{1 \mid t}(x_1 \mid x_t)$ is quite scattered, so the marginalisation giving $u_t$ involves many equally likely data samples $x_1$. In contrast, when $t$ increases and gets closer to 1, $p_{1 \mid t}(x_1 \mid x_t)$ gets quite concentrated over far fewer samples $x_1$.
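The posterior expectation defining the marginal vector field can be sketched in a simplified 1D setting. The assumptions here differ from the 2D example above and are only for illustration: $p_0 = \mathcal{N}(0, 1)$, $q_1 = \mathcal{N}(m, s^2)$, and the conditional path $p_t(x \mid x_1) = \mathcal{N}(t x_1, (1 - t)^2)$ with conditional vector field $u_t(x \mid x_1) = (x_1 - x) / (1 - t)$. Rather than sampling the posterior directly, the sketch draws $x_1 \sim q_1$ and uses self-normalised weights proportional to $p_t(x \mid x_1)$, which is a swapped-in estimator, not the plain average written above.

```python
import math
import random


def normal_pdf(x, mean, std):
    return math.exp(-0.5 * ((x - mean) / std) ** 2) / (std * math.sqrt(2.0 * math.pi))


def conditional_vf(t, x, x1):
    """u_t(x | x1) for the path N(t * x1, (1 - t)^2)."""
    return (x1 - x) / (1.0 - t)


def marginal_vf_estimate(t, x, x1_samples):
    """Self-normalised estimate of E_{p(x1 | x)}[u_t(x | x1)] using x1 ~ q1."""
    weights = [normal_pdf(x, t * x1, 1.0 - t) for x1 in x1_samples]
    total = sum(weights)
    return sum(w * conditional_vf(t, x, x1) for w, x1 in zip(weights, x1_samples)) / total


def marginal_vf_exact(t, x, m, s):
    """Closed form when q1 = N(m, s^2) and p0 = N(0, 1) (jointly Gaussian case)."""
    var_t = (1.0 - t) ** 2 + (t * s) ** 2
    posterior_mean = m + t * s * s / var_t * (x - t * m)
    return (posterior_mean - x) / (1.0 - t)


rng = random.Random(0)
m, s = 2.0, 1.0
x1_samples = [rng.gauss(m, s) for _ in range(20000)]
t, x = 0.5, 0.5
estimate = marginal_vf_estimate(t, x, x1_samples)
exact = marginal_vf_exact(t, x, m, s)
print(estimate, exact)
```

The weights make explicit which data samples $x_1$ dominate the average at a given $(t, x)$: they are broad for small $t$ and sharply peaked as $t$ approaches 1.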

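Moreover, equipped with the knowledge of \eqref{eq:cf-from-cond-vf}, we can replace

\[\mathcal{L}_{\text{FM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1], \, x \sim p_t} \left[ \| u_\theta(t, x) - u_t(x) \|^2 \right],\]

where $u_t(x) = \mathbb{E}_{x_1 \sim p_{1 \mid t}} \left[ u_t(x \mid x_1) \right]$, with an equivalent loss regressing the conditional vector field $u_t(x \mid x_1)$ and marginalising $x_1$ instead:

\[\mathcal{L}_{\text{CFM}}(\theta) = \mathbb{E}_{t \sim \mathcal{U}[0, 1], \, x_1 \sim q_1, \, x \sim p_t(\cdot \mid x_1)} \left[ \| u_\theta(t, x) - u_t(x \mid x_1) \|^2 \right].\]

These losses are equivalent in the sense that

\[\nabla_\theta \mathcal{L}_{\text{FM}}(\theta) = \nabla_\theta \mathcal{L}_{\text{CFM}}(\theta),\]

which implies that we can use \({\mathcal{L}}_{\text{CFM}}\) instead to train the parametric vector field $u_{\theta}$. We defer the full proof to the footnote, but show the key idea below. By expanding the squared norm in both losses, we can easily show that the squared terms are either equal or independent of $\theta$. Let’s develop the inner product term for \({\mathcal{L}}_{\text{FM}}\) and show that it is equal to the inner product of \({\mathcal{L}}_{\text{CFM}}\):

\[\begin{equation*} \begin{split} \mathbb{E}_{x \sim p_t} \left[ \langle u_\theta(t, x), u_t(x) \rangle \right] &= \int \langle u_\theta(t, x), \hltwo{u_t(x)} \rangle p_t(x) \textrm{d}x \\ &= \int \Big\langle u_\theta(t, x), \hltwo{\int u_t(x \mid x_1) \frac{p_t(x \mid x_1) q_1(x_1)}{p_t(x)} \textrm{d}x_1} \Big\rangle p_t(x) \textrm{d}x \\ &= \int \int \langle u_\theta(t, x), u_t(x \mid x_1) \rangle p_t(x \mid x_1) q_1(x_1) \textrm{d}x_1 \textrm{d}x \\ &= \mathbb{E}_{x_1 \sim q_1, \, x \sim p_t(\cdot \mid x_1)} \left[ \langle u_\theta(t, x), u_t(x \mid x_1) \rangle \right] \end{split} \end{equation*}\]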

where in the $\hltwo{\text{first highlighted step}}$ we used the expression of $u_t(x)$ in \eqref{eq:cf-from-cond-vf}.

The benefit of the CFM loss is that once we define the conditional probability path $p_t(x \mid x_1)$, we can construct an unbiased Monte Carlo estimator of the objective using samples $\big( x_1^{(i)} \big)_{i = 1}^n$ from the data target $q_1$!

This estimator can be computed efficiently: it involves an expectation over the joint $q_1(x_1)p_t(x \mid x_1)$ of the conditional vector field $u_t (x \mid x_1)$, both of which are available, as opposed to the marginal vector field $u_t$, which involves an expectation over the posterior $p_{1 \mid t}(x_1 \mid x)$.

We note that, as opposed to the log-likelihood maximisation loss of CNFs which does not put any preference over which vector field $u_t$ can be learned, the CFM loss does specify one via the choice of a conditional vector field, which will be regressed by the neural vector field $u_\theta$.
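A Monte Carlo estimate of the CFM loss can be sketched in a few lines of plain Python. Everything specific below is an assumption for illustration: a 1D problem, a standard normal reference, a linear Gaussian conditional path $x_t = t x_1 + (1 - (1 - \sigma_{\min}) t) x_0$ with its matching conditional vector field, and a candidate `u_theta` passed as an ordinary function rather than a neural network.

```python
import random

SIGMA_MIN = 0.01  # assumed small terminal standard deviation


def sample_conditional_path(t, x1, rng):
    """Draw x_t ~ p_t(. | x1) by pushing x0 ~ N(0, 1) through the conditional flow."""
    x0 = rng.gauss(0.0, 1.0)
    return t * x1 + (1.0 - (1.0 - SIGMA_MIN) * t) * x0


def conditional_vf(t, x, x1):
    """Regression target u_t(x | x1) for the path above."""
    return (x1 - (1.0 - SIGMA_MIN) * x) / (1.0 - (1.0 - SIGMA_MIN) * t)


def cfm_loss(u_theta, data, n_samples, rng):
    """Monte Carlo estimate of E_{t, x1, x_t} |u_theta(t, x_t) - u_t(x_t | x1)|^2."""
    total = 0.0
    for _ in range(n_samples):
        t = rng.random()       # t ~ U[0, 1)
        x1 = rng.choice(data)  # x1 ~ empirical q1
        xt = sample_conditional_path(t, x1, rng)
        total += (u_theta(t, xt) - conditional_vf(t, xt, x1)) ** 2
    return total / n_samples


data = [-1.0, 1.0]  # toy two-point "dataset"
print(cfm_loss(lambda t, x: 0.0, data, 1000, random.Random(0)))
```

No ODE is solved anywhere in `cfm_loss`, which is the simulation-free property. With a single data point the conditional and marginal vector fields coincide, so plugging in `conditional_vf` itself as `u_theta` gives a loss of exactly zero.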

Gaussian probability paths

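Let’s now look at a practical example of a conditional vector field and the corresponding probability path. Suppose we want a conditional vector field which generates a path of Gaussians, i.e.

\[p_t(x \mid x_1) = \mathcal{N}\big(x; \mu_t(x_1), \sigma_t(x_1)^2 I\big)\]

for some mean $\mu_t(x_1)$ and standard deviation $\sigma_t(x_1)$.

One conditional vector field inducing the above-defined conditional probability path is given by the following expression:

\[\begin{equation} \label{eq:gaussian-path} u_t(x \mid x_1) = \frac{\dot{\sigma}_t(x_1)}{\sigma_t(x_1)} \big( x - \mu_t(x_1) \big) + \dot{\mu}_t(x_1) \end{equation}\]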

as shown in the proof below.

Proof

We have $$ \begin{equation*} \phi_t(x \mid x_1) = \mu_t(x_1) + \sigma_t(x_1) x \end{equation*} $$ and we want to determine $u_t(x \mid x_1)$ such that $$ \begin{equation*} \frac{\dd}{\dd t} \phi_t(x \mid x_1) = u_t \big( \phi_t(x \mid x_1) \mid x_1 \big) \end{equation*} $$ First note that the LHS is $$ \begin{equation*} \begin{split} \frac{\dd{}}{\dd{} t} \phi_t(x \mid x_1) &= \frac{\dd{}}{\dd{} t} \bigg( \mu_t(x_1) + \sigma_t(x_1) x \bigg) \\ &= \dot{\mu_t}(x_1) + \dot{\sigma_t}(x_1) x \end{split} \end{equation*} $$ so we have $$ \begin{equation*} \dot{\mu_t}(x_1) + \dot{\sigma_t}(x_1) x = u_t \big( \phi_t(x \mid x_1) \mid x_1 \big) \end{equation*} $$ Suppose that $u_t$ is of the form $$ \begin{equation*} u_t\big( \phi_t(x \mid x_1) \mid x_1\big) = h\big(t, \phi_t(x \mid x_1), x_1\big) \dot{\mu_t}(x_1) + g\big(t, \phi_t(x \mid x_1), x_1\big) \dot{\sigma_t}(x_1) \end{equation*} $$ for some functions $h$ and $g$. Reading off the components from the previous equation, we then see that we require $$ \begin{equation*} h\big(t, \phi_t(x \mid x_1), x_1\big) = 1 \quad \text{and} \quad g\big(t, \phi_t(x \mid x_1), x_1\big) = x \end{equation*} $$ The simplest solution to the above is then just $$ \begin{equation*} h(t, x, x_1) = 1 \end{equation*} $$ i.e. a constant function, and $$ \begin{equation*} g(t, x, x_1) = \phi_t^{-1}(x \mid x_1) = \frac{x - \mu_t(x_1)}{\sigma_t(x_1)} \end{equation*} $$ such that $$ \begin{equation*} g\big(t, \phi_t(x \mid x_1), x_1\big) = \phi_t^{-1} \big( \phi_t(x \mid x_1) \mid x_1 \big) = x \end{equation*} $$ resulting in $$ \begin{equation*} u_t \big( x \mid x_1 \big) = \dot{\mu_t}(x_1) + \dot{\sigma_t}(x_1) \bigg( \frac{x - \mu_t(x_1)}{\sigma_t(x_1)} \bigg) \end{equation*} $$ as claimed.

Example: Linear interpolation

A simple choice for the mean $\mu_t(x_1)$ and std. $\sigma_t(x_1)$ is the linear interpolation for both, i.e. \[\begin{align*} \hlone{\mu_t(x_1)} &\triangleq t x_1 \quad &\text{and} \quad \hlthree{\sigma_t(x_1)} &\triangleq (1 - t) + t \sigmamin \\ \hltwo{\dot{\mu}_t(x_1)} &= x_1 \quad &\text{and} \quad \hlfour{\dot{\sigma}_t(x_1)} &= -1 + \sigmamin \end{align*}\] so that, for $x_0 \sim p_0$, \[\begin{equation*} \big( {\hlone{\mu_0(x_1)}} + {\hlthree{\sigma_0(x_1)}} x_0 \big) \sim p_0 \quad \text{and} \quad \big( {\hlone{\mu_1(x_1)}} + {\hlthree{\sigma_1(x_1)}} x_0 \big) \sim \mathcal{N}(x_1, \sigmamin^2 I) \end{equation*}\] In addition, letting $p_0 = \mathcal{N}([-\mu, 0], I)$ and $p_1 = \mathcal{N}([+\mu, 0], I)$ for some $\mu > 0$, we’re back to the \ref{eq:g2g} example from earlier. We can then plug this choice of $\mu_t(x_1)$ and $\sigma_t(x_1)$ into \eqref{eq:gaussian-path} to obtain the conditional vector field, writing $\hlthree{\sigma_t(x_1)} = 1 - (1 - \sigmamin) t$ to make our lives simpler, \[\begin{equation*} \begin{split} u_t(x \mid x_1) &= \frac{\hlfour{- (1 - \sigmamin)}}{\hlthree{1 - (1 - \sigmamin) t}} (x - \hlone{t x_1}) + \hltwo{x_1} \\ &= \frac{1}{(1 - t) + t \sigmamin} \bigg( - (1 - \sigmamin) (x - t x_1) + \big(1 - (1 - \sigmamin) t \big) x_1 \bigg) \\ &= \frac{1}{(1 - t) + t \sigmamin} \bigg( - (1 - \sigmamin) x + x_1 \bigg) \\ &= \frac{x_1 - (1 - \sigmamin) x}{1 - (1 - \sigmamin) t}. \end{split} \end{equation*}\] Below you can see the difference between $\phi_t(x_0)$ (top figure) and $\phi_t(x_0 \mid x_1)$ (bottom figure) for pairs $(x_0, x_1)$ with $x_0 \sim p_0$ and $x_1 = \phi_1(x_0)$. The paths are coloured by the sign of the 2nd vector component of $x_0$ to more clearly highlight the difference between the marginal and conditional flows.
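The closed-form conditional vector field for the linear interpolation can be checked numerically against the time derivative of the conditional flow $\phi_t(x_0 \mid x_1) = \mu_t(x_1) + \sigma_t(x_1) x_0$. This is a minimal 1D sketch; the value of `SIGMA_MIN`, the test points and the finite-difference step are arbitrary assumptions.

```python
SIGMA_MIN = 0.01  # assumed value of sigma_min


def mu_t(t, x1):
    return t * x1


def sigma_t(t):
    return 1.0 - (1.0 - SIGMA_MIN) * t


def phi(t, x0, x1):
    """Conditional flow phi_t(x0 | x1) = mu_t(x1) + sigma_t * x0."""
    return mu_t(t, x1) + sigma_t(t) * x0


def conditional_vf(t, x, x1):
    """Closed-form u_t(x | x1) for the linear interpolation."""
    return (x1 - (1.0 - SIGMA_MIN) * x) / (1.0 - (1.0 - SIGMA_MIN) * t)


def time_derivative(t, x0, x1, h=1e-6):
    """Central finite-difference estimate of d/dt phi_t(x0 | x1)."""
    return (phi(t + h, x0, x1) - phi(t - h, x0, x1)) / (2.0 * h)


for t, x0, x1 in [(0.1, -0.3, 2.0), (0.5, 1.2, -1.0), (0.9, 0.4, 0.7)]:
    lhs = time_derivative(t, x0, x1)
    rhs = conditional_vf(t, phi(t, x0, x1), x1)
    print(t, lhs, rhs)
```

Both columns should match $x_1 - (1 - \sigma_{\min}) x_0$, which does not depend on $t$: each conditional path is a straight line traversed at constant velocity.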

But is CFM really all rainbows and unicorns?

Unfortunately not, no. There are two issues arising from crossing conditional paths. We will explain this just after, but now we stress that this leads to

Non-straight marginal paths $\Rightarrow$ ODE hard to integrate $\Rightarrow$ slow sampling at inference.

Many possible $x_1$ for a noised $x_t$ $\Rightarrow$ high CFM loss variance $\Rightarrow$ slow training convergence.

To get a better understanding of these two points, let’s revisit the \ref{eq:g2g} example once more. Realizations of the conditional vector field $u_t(x \mid x_1)$, i.e. sampling from the process

\[x_0 \sim p_0, \quad x_1 \sim q_1, \quad x_t = \phi_t(x_0 \mid x_1),\]

result in paths that are quite different from the marginal paths, as illustrated in the figures below.

In particular, we can see that the marginal paths $\phi_t(x)$ do not cross; this is indeed just the uniqueness property of ODE solutions. A realization of the conditional vector field $u_t(x \mid x_1)$ also exhibits the “non-crossing paths” property, similar to the marginal flows $\phi_t(x)$, however paths $\phi_t(x \mid x_1)$ corresponding to different realizations $x_1 \sim q_1$ may intersect, as highlighted in the figure above.

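Consider two highlighted paths in the visualization of $u_t(x \mid x_1)$, with data samples $\hlone{x_1^{(1)}}$ and $\hlthree{x_1^{(2)}}$. When learning a parameterized vector field $u_{\theta}(t, x)$ via stochastic gradient descent (SGD), we approximate the CFM loss as:

\[\mathcal{L}_{\text{CFM}}(\theta) \approx \frac{1}{2} \big\| u_\theta(t, \hlone{x_t^{(1)}}) - u_t(\hlone{x_t^{(1)}} \mid \hlone{x_1^{(1)}}) \big\|^2 + \frac{1}{2} \big\| u_\theta(t, \hlthree{x_t^{(2)}}) - u_t(\hlthree{x_t^{(2)}} \mid \hlthree{x_1^{(2)}}) \big\|^2\]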

where $t \sim \mathcal{U}[0, 1]$, $\hlone{x_1^{(1)}}, \hlthree{x_1^{(2)}} \sim q_1$, and $\hlone{x_t^{(1)}} \sim p_t(\cdot \mid \hlone{x_1^{(1)}}), \hlthree{x_t^{(2)}} \sim p_t(\cdot \mid \hlthree{x_1^{(2)}})$. We compute the gradient with respect to $\theta$ for a gradient step.

In such a scenario, we’re attempting to align $u_{\theta}(t, x)$ with two different vector fields whose corresponding paths are impossible under the marginal vector field $u(t, x)$ that we’re trying to learn! This fact can lead to increased variance in the gradient estimate, and thus slower convergence.

In slightly more complex scenarios, the situation becomes even more striking. Below we see a nice example from Liu et al. (2022) where our reference and target are two different mixtures of Gaussians in 2D differing only by the sign of the mean in the x-component. Specifically,

where we set $\mu = 10$, unless otherwise specified.

Here we see that marginal paths (bottom figure) end up looking very different from the conditional paths (top figure). Indeed, at training time paths may intersect, whilst at sampling time they cannot (due to the uniqueness of the ODE solution). As such we see on the bottom plot that some (marginal) paths are quite curved and would therefore require a greater number of discretisation steps from the ODE solver during inference.

We can also see how this leads to a significant variance of the CFM loss estimate for $t \approx 0.5$ in the figure below. More generally, samples from the reference distribution which are arbitrarily close to each other can be associated with either target mode, leading to high variance in the vector field regression loss.

An intuitive solution would be to associate data samples with reference samples which are close instead of some arbitrary pairing. We’ll detail this idea next via the concept of couplings and optimal transport.

Coupling

So far we have constructed the vector field $u_t$ by conditioning and marginalising over data points $x_1$. This is referred to as a one-sided conditioning, where the probability path is constructed by marginalising over $z=x_1$:

e.g. \(p(x_t \mid x_1) = \mathcal{N}(x_t \mid x_1, (1-t)^2)\).

Yet, more generally, we can consider conditioning and marginalising over latent variables $z$, and minimising the following loss:

As suggested in Liu et al. (2023), Tong et al. (2023), Albergo & Vanden-Eijnden (2022) and Pooladian et al. (2023) one can condition on both endpoints $z=(x_1, x_0)$ of the process, referred as two-sided conditioning. The marginal probability path is defined as:

The following boundary condition on $p_t(x_t \mid x_1, x_0)$: $p_0(\cdot \mid x_1, x_0)=\delta_{x_0}$ and $p_1(\cdot \mid x_1, x_0) = \delta_{x_1}$ is required so that the marginal has the proper conditions $p_0 = q_0$ and $p_1 = q_1$.

For instance, a deterministic linear interpolation gives $p(x_t \mid x_0, x_1) = \delta_{(1-t) x_0 + t x_1}(x_t)$, and the simplest choice regarding the coupling $z = (x_1, x_0)$ is to consider independent samples: $q(x_1, x_0) = q_1(x_1) q_0(x_0)$.

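One main advantage is that this allows for a non-Gaussian reference distribution $q_0$. Choosing a standard normal as the noise distribution, $q(x_0) = \mathcal{N}(0, \mathrm{I})$, we recover the same one-sided conditional probability path as earlier:

$$p(x_t \mid x_1) = \int p(x_t \mid x_0, x_1) \, q(x_0) \, \mathrm{d}x_0 = \mathcal{N}(x_t \mid t x_1, (1-t)^2 \mathrm{I}).$$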
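As a rough illustration of two-sided conditioning with an independent coupling, the sketch below draws $x_0$ and $x_1$ separately, interpolates linearly, and returns the time derivative $x_1 - x_0$ of that interpolation as the regression target. The helper name `sample_conditional_flow` and the shifted Gaussian standing in for data are assumptions for the example, not part of any library.

```python
import numpy as np

def sample_conditional_flow(x0, x1, t):
    """Point on the straight path between paired endpoints, plus its velocity.

    Hypothetical helper: x0, x1 have shape (B, d), t has shape (B,).
    """
    t = np.reshape(t, (-1, 1))        # one time per sample, broadcast over dimensions
    xt = (1.0 - t) * x0 + t * x1      # deterministic linear interpolation
    ut = x1 - x0                      # d/dt of xt, used as the regression target
    return xt, ut

rng = np.random.default_rng(0)
B, d = 8, 2
x0 = rng.standard_normal((B, d))          # reference samples, q_0 = N(0, I)
x1 = rng.standard_normal((B, d)) + 5.0    # stand-in "data" samples (assumed for the demo)
t = rng.uniform(size=B)                   # times drawn uniformly on [0, 1]

xt, ut = sample_conditional_flow(x0, x1, t)
```

Since `x0` only enters through the samples passed in, nothing here requires the reference to be Gaussian; swapping in samples from another $q_0$ leaves the function unchanged.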

Optimal Transport (OT) coupling

Now let’s go back to the idea of not using an independent coupling (i.e. pairing) but instead to correlate pairs $(x_1, x_0)$ with a joint $q(x_1, x_0) \neq q_1(x_1) q_0(x_0)$. Tong et al. (2023) and Pooladian et al. (2023) suggest using the optimal transport coupling

which minimises the optimal transport (i.e. Wasserstein) cost (Monge, 1781; Peyré and Cuturi, 2020). The OT coupling $\pi$ associates samples $x_0$ and $x_1$ such that the total distance is minimised.

This OT coupling is illustrated on the right-hand side of the figure below, adapted from Tong et al. (2023). In contrast to the middle figure, which uses an independent coupling, the OT one does not have paths that cross. This leads to lower training variance and faster sampling.

In practice, we cannot compute the optimal coupling $\pi$ between $x_1 \sim q_1$ and $x_0 \sim q_0$, as algorithms solving this problem are only known for finite distributions. In fact, finding a map from $q_0$ to $q_1$ is the generative modelling problem that we are trying to solve in the first place!

Tong et al. (2023) and Pooladian et al. (2023) propose to approximate the OT coupling $\pi$ by computing the optimal coupling only over each mini-batch of data and noise samples, an approach coined mini-batch OT (Fatras et al., 2020). This is scalable because, for a finite collection of samples, the OT problem can be (approximately) solved with quadratic complexity in the batch size via the Sinkhorn algorithm (Peyré and Cuturi, 2020). This results in a joint distribution $\gamma(i, j)$ over “inputs” \(\big(x_0^{(i)}\big)_{i=1,\dots,B}\) and “outputs” \(\big(x_1^{(j)}\big)_{j=1,\dots,B}\) such that the expected distance is (approximately) minimised. Finally, to construct a mini-batch from this $\gamma$ which we can subsequently use for training, we can either compute the expectation with respect to $\gamma(i, j)$ by considering all $B^2$ pairs (in practice, this can often boil down to only needing to consider $B$ disjoint pairs) or sample a new collection of training pairs $(x_0^{(i')}, x_1^{(j')})$ with $(i', j') \sim \gamma$.
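The re-pairing step can be sketched in a few lines. Note that this sketch swaps Sinkhorn for an exact assignment solver, `scipy.optimize.linear_sum_assignment`: with uniform weights and equally sized batches, an optimal coupling can be taken to be a permutation, which gives the $B$ disjoint pairs mentioned above. The function name `minibatch_ot_pairs` and the toy data are assumptions for the example.

```python
import numpy as np
from scipy.optimize import linear_sum_assignment

def minibatch_ot_pairs(x0, x1):
    """Re-pair a mini-batch so that the total squared Euclidean distance is minimal.

    Hypothetical helper; uses an exact assignment solver rather than Sinkhorn.
    """
    # cost[i, j] = ||x0[i] - x1[j]||^2
    cost = ((x0[:, None, :] - x1[None, :, :]) ** 2).sum(axis=-1)
    rows, cols = linear_sum_assignment(cost)   # optimal one-to-one matching
    return x0[rows], x1[cols], cost

rng = np.random.default_rng(0)
B, d = 64, 2
x0 = rng.standard_normal((B, d))                           # noise samples
x1 = rng.standard_normal((B, d)) + np.array([4.0, 0.0])    # stand-in data samples

x0_ot, x1_ot, cost = minibatch_ot_pairs(x0, x1)
independent_cost = np.trace(cost) / B                  # mean cost of the pairs (i, i) as drawn
ot_cost = ((x0_ot - x1_ot) ** 2).sum(axis=-1).mean()   # mean cost after re-pairing

print(f"independent: {independent_cost:.3f}, mini-batch OT: {ot_cost:.3f}")
```

The exact solver scales roughly cubically in the batch size, so for large batches an entropic solver such as Sinkhorn is the more usual choice; the re-paired `(x0_ot, x1_ot)` can then be used wherever independently drawn pairs would be.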

For example, we can apply this to the \eqref{eq:g2g} example from before, which almost completely removes the crossing paths behaviour described earlier, as can be seen in the figure below.

We also observe similar behaviour when applying this to the more complex example \eqref{eq:mog2mog}, as can be seen in the figure below.

All in all, making use of mini-batch OT seems to be a strict improvement over the uniform sampling approach to constructing the mini-batch in the above examples and has been shown to improve practical performance in a wide range of applications (Tong et al., 2023; Klein et al., 2023).

It’s worth noting that in \eqref{eq:ot} we only considered choosing the coupling $\gamma(i, j)$ such that we minimise the expected squared Euclidean distance. This works well in the examples \eqref{eq:g2g} and \eqref{eq:mog2mog}, but we could also replace the squared Euclidean distance with some other cost function when constructing the coupling $\gamma(i, j)$. For example, if we were modelling molecules using CNFs, it might make sense to pick $(i, j)$ such that $x_0^{(i)}$ and $x_1^{(j)}$ are also rotationally aligned, as is done in the work of Klein et al. (2023).

Quick Summary

In short, we’ve shown that flow matching is an efficient approach to training continuous normalising flows (CNFs), by directly regressing over the vector field instead of explicitly training by maximum likelihood. This is enabled by constructing the target vector field as the marginalisation of simple conditional vector fields which (marginally) interpolate between the reference and data distributions and which, crucially, we can evaluate and integrate over time. A neural network parameterising the vector field can then be trained by regressing over these conditional vector fields. Similarly to CNFs, samples can be obtained at inference time by solving the ODE associated with the neural vector field.
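For the sampling step, a minimal sketch of a fixed-step Euler solver is given below. In place of a trained network it uses a closed-form field, assumed purely for illustration: when the target is a point mass at $m$ and paths are straight lines, the marginal vector field is $u_t(x) = (m - x)/(1 - t)$. The names `euler_sample` and `toy_field` are hypothetical.

```python
import numpy as np

def euler_sample(vector_field, x0, n_steps=100):
    """Integrate dx/dt = vector_field(t, x) from t=0 to t=1 with fixed-step Euler."""
    x = np.array(x0, dtype=float)
    h = 1.0 / n_steps
    for k in range(n_steps):
        x = x + h * vector_field(k * h, x)   # field is only evaluated at times t < 1
    return x

# Stand-in for a trained network (assumption): the exact field that
# transports any reference distribution onto a point mass at m.
m = np.array([2.0, -1.0])

def toy_field(t, x):
    return (m - x) / (1.0 - t)

rng = np.random.default_rng(0)
x0 = rng.standard_normal((5, 2))              # reference samples
samples = euler_sample(toy_field, x0, n_steps=50)
```

In practice one would pass the learned vector field instead of `toy_field`, and an adaptive or higher-order solver may be preferable when the marginal paths are strongly curved.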

In this post we have deliberately not talked about diffusion (i.e. score-based) models, as they are not necessary for understanding flow matching. Yet the two are deeply related, and even exactly the same in some settings. We are planning to explore these connections, along with generalisations, in a follow-up post!

