June 21, 2020

Probabilistic modeling using normalizing flows pt.1

Probabilistic models give a rich representation of observed data and allow us to quantify uncertainty, detect outliers, and perform simulations. Classic probabilistic modeling requires us to model our domain with conditional probabilities, which is not always feasible. This is particularly true for high-dimensional data such as images or audio. In these scenarios, we would like to learn the data distribution without all the modeling assumptions. Normalizing flows are a powerful class of models that allow us to do just that, without resorting to approximations. They work by learning a sequence of transformations, or flows, that transforms a simple distribution into one that fits the observed data. There are several different flavours of normalizing flows, and in this blog article we are going to implement them using affine coupling layers in PyTorch. To keep things focused, this article will cover the theory and the model implementation, and in a follow-up article we will see how the model works in practice by fitting it to some data.

Normalizing flows

In a normalizing flows model we define an observed stochastic variable \( x \in \mathbb{R}^D, x \sim p_X, \) a latent stochastic variable \( z \in \mathbb{R}^D, z \sim p_Z \) and a bijective and differentiable function \( z = f(x): \mathbb{R}^D \mapsto \mathbb{R}^D \) with inverse \( g = f^{-1} \). The change of variable formula tells us that the (log) densities \(p_X\) and \(p_Z\) are related through

\begin{equation} \begin{split} \log p_X(x) & = \log p_Z(z) + \log \left \vert \det \frac{dz}{dx} \right \vert \iff \\ \log p_X(x) & = \log p_Z(f(x)) + \log \left \vert \det \frac{df(x)}{dx} \right \vert, \end{split} \end{equation}

where \( \frac{df(x)}{dx} \) is the Jacobian of \(f\) evaluated at \(x\). What’s great about this formula is that given \( p_Z \) and \( f \), we have an exact expression for \(p_X\) expressed in \(x\); we do not have to resort to methods like maximizing the ELBO or Monte Carlo sampling, but can instead train by directly maximizing \(\log p_X(x)\). We will use this result to design an algorithm that can learn \( p_X \) from data, which means we have to specify \( p_Z \) and \( f \). For our purposes \( p_Z \) can be any distribution we can evaluate the pdf of and sample from, and it is common to let \( p_Z = \mathcal{N}\). \(f\), on the other hand, will be the part of the model we learn from data. However, we cannot learn just any \(f\). This is because computing \( \det \frac{df(x)}{dx} \) is \( \mathcal{O}(D^3) \) in the general case, so a naive design (i.e. just plugging a neural network in there) would make inference intractable for high-dimensional data. Fortunately, this complexity can be reduced by limiting the interactions between dimensions in \( f \), but this has the unwanted side-effect of reducing the model's expressivity, making it worse at modeling complex distributions. This trade-off hints at the central problem in normalizing flows research: how can we design expressive bijections \( f \) with cheap-to-compute Jacobians? It turns out that this is much easier if we decompose \( f \) into \(n\) smaller functions, or flows, \( f = f_n \circ \dots \circ f_2 \circ f_1 \). By the chain rule, the Jacobian of this composition is given by

\begin{equation} \frac{df(x)}{dx} = \prod_{i=1}^n \frac{df_i(x_{i-1})}{dx_{i-1}} \end{equation}

which means its log determinant is

\begin{equation} \log \left \vert \det \frac{df(x)}{dx} \right \vert = \log \left \vert \det \prod_{i=1}^n \frac{df_i(x_{i-1})}{dx_{i-1}} \right \vert = \sum_{i=1}^n \log \left \vert \det \frac{df_i(x_{i-1})}{dx_{i-1}} \right \vert, \end{equation}

where \( x_{i} \) denotes the output of the \(i\)-th function and \(x_0 = x\). For \(f\) to be a differentiable bijection, we require every \(f_i\) to be differentiable and bijective, but apart from that we can choose the flows as we please.

Figure 1: Illustration of the composition of normalizing flows. Each transformation \(f_i\) modifies \(p_X\) to be more and more like \(p_Z\). Commonly \( p_Z = \mathcal{N} \), so we refer to \(f\) as a normalizing flow, while its inverse \(g\) is known as a generative flow.
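
Before designing the flows, it is worth convincing ourselves that these formulas hold in practice. The following is a minimal numerical sanity check of my own (not part of the model we will build), using a one-dimensional example where both the density and the flows are known in closed form:

import torch
from torch.distributions import Normal

# Toy 1D example: x ~ N(mu, sigma^2), and the two flows f1(x) = x - mu and
# f2(y) = y / sigma compose to z = (x - mu) / sigma, which follows N(0, 1).
mu, sigma = torch.tensor(2.0), torch.tensor(0.5)
p_X, p_Z = Normal(mu, sigma), Normal(0.0, 1.0)

x = torch.linspace(-1.0, 5.0, 7)
z = (x - mu) / sigma

# log|det df1/dx| = log 1 = 0 and log|det df2/dy| = -log sigma,
# so the summed log determinant of the composition is -log sigma.
sum_log_abs_det = -torch.log(sigma)

lhs = p_X.log_prob(x)                    # log p_X(x)
rhs = p_Z.log_prob(z) + sum_log_abs_det  # log p_Z(f(x)) + log|det df/dx|
print(torch.allclose(lhs, rhs))          # True

The same bookkeeping is exactly what the implementation below does, just with learned flows in higher dimensions.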

Since the Jacobian only depends on the output of the previous flow, we can compute it alongside \(f_i(x)\) with no additional overhead. We will see how to design the flows \(f_i\) shortly, but let's take a top-down view when implementing this and start with a class for the entire normalizing flow.

  from typing import Callable, List, Tuple

  import torch
  import torch.nn as nn
  import torch.nn.functional as F
  from torch.distributions import Distribution


  class NormalizingFlow(nn.Module):

    def __init__(self, latent: Distribution, flows: List[nn.Module]):
      super(NormalizingFlow, self).__init__()
      self.latent = latent
      # Wrap in a ModuleList so the flows' parameters are registered with the model
      self.flows = nn.ModuleList(flows)

    def latent_log_prob(self, z: torch.Tensor) -> torch.Tensor:
      return self.latent.log_prob(z)

    def latent_sample(self, num_samples: int = 1) -> torch.Tensor:
      return self.latent.sample((num_samples,))

    def sample(self, num_samples: int = 1) -> torch.Tensor:
      '''Sample a new observation x by sampling z from
      the latent distribution and passing it through g.'''
      return self.g(self.latent_sample(num_samples))

    def f(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
      '''Maps observation x to latent variable z.
      Additionally, computes the log determinant
      of the Jacobian for this transformation.
      Inverse of g.'''
      z, sum_log_abs_det = x, torch.zeros(x.size(0)).to(x.device)
      for flow in self.flows:
        z, log_abs_det = flow.f(z)
        sum_log_abs_det += log_abs_det

      return z, sum_log_abs_det

    def g(self, z: torch.Tensor) -> torch.Tensor:
      '''Maps latent variable z to observation x.
      Inverse of f.'''
      with torch.no_grad():
        x = z
        for flow in reversed(self.flows):
          x = flow.g(x)

        return x

    def g_steps(self, z: torch.Tensor) -> List[torch.Tensor]:
      '''Maps latent variable z to observation x
         and stores intermediate results.'''
      xs = [z]
      for flow in reversed(self.flows):
        xs.append(flow.g(xs[-1]))

      return xs

    def log_prob(self, x: torch.Tensor) -> torch.Tensor:
      '''Computes log p(x) using the change of variable formula.'''
      z, log_abs_det = self.f(x)
      return self.latent_log_prob(z) + log_abs_det

    def __len__(self) -> int:
      return len(self.flows)

You can ignore the convenience methods in this class for now; it is in the function f that the magic happens. Given an observation \(x\) we pass it through the composition of flows while simultaneously computing the log determinant of the Jacobian. With this implementation, training the model can be done by maximizing log_prob. g is simply implemented by applying the inverses of the flows in reverse order, and here we do not need to compute the log determinant of the Jacobian.

And with that we can turn our attention to the individual flows \( f_i \). One design that has proven successful is coupling layers.

Coupling layers

Coupling layers are a fairly popular approach. They design \( f_i \) by first ordering the data dimensions and then splitting them into two parts \( x_{\leq d}, x_{>d} = x \), where \(1 < d < D \). \( f_i \) is then given by

\begin{equation} f_i(x_j) = \begin{cases} x_j & \quad j \leq d\\ \tau_i(x_j; \theta_i(x_{\leq d})) & \quad j > d \end{cases} \end{equation}

where \( \tau_i \) is often referred to as the transformer, and \( \theta_i \) the conditioner. This naming is somewhat unfortunate, since an unrelated but fairly popular paper also introduced a Transformer; to avoid confusion we will refer to \( \tau_i \) as the coupling function instead. We are going to pick a specific function class for the coupling function, and then learn the conditioner from data.

When splitting the data, \(d\) is commonly chosen to be \(d = \frac{D}{2}\), but how the dimensions are ordered varies. If the data has no natural order, we can always impose our own, and if we know the data has some specific structure we could create specific patterns. For instance, a checkerboard pattern can be used for image data to capture the locality of pixels. At this point you might be wondering why we are mapping half of the dimensions over the identity function. Indeed, we cannot model particularly interesting densities by only transforming half of the dimensions. The reason we are doing this is to limit the interactions between dimensions and avoid the dreaded cubic complexity. Given this formulation, the Jacobian is the block triangular matrix

\begin{equation*} \frac{\partial f_i}{\partial x} = \begin{pmatrix} I & 0 \\ \frac{\partial \tau_i}{\partial x_{\leq d}} & \frac{\partial \tau_i}{\partial x_{>d}} \end{pmatrix} \end{equation*}

whose determinant is the product of the determinants of its diagonal blocks. Since \( \tau_i \) acts element-wise on \( x_{>d} \), the lower right block is diagonal, so the log determinant is simply

\[ \log \left \vert \det \frac{df_i(x)}{dx} \right \vert = \log \prod_{j=d+1}^D \left \vert \frac{d\tau_i(x_j)}{dx_j} \right \vert = \sum_{j=d+1}^D \log \left \vert \frac{d\tau_i(x_j)}{dx_j} \right \vert \]

which can be computed in \( \mathcal{O}(D) \). To make sure all dimensions can be transformed, we swap the two parts between consecutive coupling layers so that they take turns being transformed.
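
The split itself can be implemented as a small function, which the coupling layer defined below takes as an argument. Here is a minimal sketch of my own, assuming an even number of dimensions and the half split \(d = \frac{D}{2}\):

# A minimal split: cut the last dimension in half, i.e. d = D / 2.
# The swapping of the two parts between flows is handled by the order
# in which the coupling layer below concatenates its outputs.
def split_half(x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
  return x.chunk(2, dim=-1)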

Figure 2: Illustration of a coupling layer. Dimensions \(x_{\leq d}\) are used to compute the parameters \(\theta\) of the coupling function and are then mapped over the identity function. Dimensions \(x_{> d}\) are transformed by \(\tau\), and finally the two parts switch places so that \(x_{\leq d}\) is transformed in the next flow.

Since the determinant is a product of derivatives of \( \tau_i \), it is desirable that \( \tau_i \) is fast to evaluate and has a simple derivative, preferably in closed form. While a more expressive \(\tau_i\) makes for a more expressive model overall, we can get away with quite simple transformations. If we need to increase the model’s expressivity we can simply increase the number of flows.

Affine coupling layers

One choice of coupling function is a simple affine transformation \( \tau_i(x) = \exp(s_i) x + t_i \), where the scale is exponentiated to make it non-zero and guarantee invertibility. This transformation can be computed efficiently and has a simple derivative \( \frac{d}{dx} \left ( \exp(s_i) x + t_i \right ) = \exp(s_i) \), which gives the pleasant log Jacobian determinant

\[ \log \left \vert \det \frac{df_i(x)}{dx} \right \vert = \sum_{j=d+1}^D \log \left \vert \frac{d\tau_i(x_j)}{dx_j} \right \vert = \sum_{j=d+1}^D \log \left \vert \exp(s_i^{(j)}) \right \vert = \sum_{j=d+1}^D s_i^{(j)} \]

where we use superscript to index elements in \(s_i\). We implement this as follows.

class AffineCouplingLayer(nn.Module):

  def __init__(
     self,
     theta: nn.Module,
     split: Callable[[torch.Tensor], Tuple[torch.Tensor, torch.Tensor]]
  ):
    super(AffineCouplingLayer, self).__init__()
    self.theta = theta
    self.split = split

  def f(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    '''f : x -> z. The inverse of g.'''
    x2, x1 = self.split(x)
    t, s = self.theta(x1)
    # x1 passes through unchanged, x2 is transformed by the affine coupling function
    z1, z2 = x1, x2 * torch.exp(s) + t
    log_det = s.sum(-1)
    return torch.cat((z1, z2), dim=-1), log_det

  def g(self, z: torch.Tensor) -> torch.Tensor:
    '''g : z -> x. The inverse of f.'''
    z1, z2 = self.split(z)
    t, s = self.theta(z1)
    x1, x2 = z1, (z2 - t) * torch.exp(-s)
    # Concatenate in the order expected by f, so that g is its exact inverse
    return torch.cat((x2, x1), dim=-1)
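
As a quick sanity check of the implementation we can verify that g really undoes f on random data. This is a hypothetical usage sketch of my own; ToyTheta is a stand-in for the real conditioner, which is defined in the next section:

# Hypothetical sanity check: g should exactly invert f.
class ToyTheta(nn.Module):
  '''Stand-in conditioner that returns a translation t and a log-scale s.'''

  def __init__(self, dim: int):
    super(ToyTheta, self).__init__()
    self.t = nn.Linear(dim, dim)
    self.s = nn.Linear(dim, dim)

  def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
    return self.t(x), self.s(x)

layer = AffineCouplingLayer(ToyTheta(2), split=lambda x: x.chunk(2, dim=-1))
x = torch.randn(8, 4)
z, log_abs_det = layer.f(x)
print(torch.allclose(layer.g(z), x, atol=1e-5))  # True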

The conditioner

The final piece of the model is the conditioner. The rest of the model has been carefully designed to adhere to the math, but we obviously need to fit it to observed data. We do this by learning the conditioner function, which in the case of affine coupling layers means learning a function that outputs the vectors \(s_i\) and \(t_i\).

Since the Jacobian determinant only depends on \( \tau_i \), we can make \( \theta_i \) as complex as we want without impacting the computational complexity of computing the determinant. And what do we do when we want to learn a complex differentiable function? We stick a neural network in there! Depending on what your data looks like you might want to use a specific architecture, but for now we will settle for a simple fully connected network. Learning the parameters for a single flow is a much simpler problem than learning a function for an entire classification or regression problem, as is typically done with neural networks. Hence, the conditioner can be significantly simpler and smaller than popular networks such as ResNet or BERT. We will see this in the follow-up article.

  class Conditioner(nn.Module):

    def __init__(
      self, in_dim: int, out_dim: int,
      num_hidden: int, hidden_dim: int,
      num_params: int
    ):
      super(Conditioner, self).__init__()
      self.input = nn.Linear(in_dim, hidden_dim)
      self.hidden = nn.ModuleList([
        nn.Linear(hidden_dim, hidden_dim)
        for _ in range(num_hidden)
      ])

      self.num_params = num_params
      self.out_dim = out_dim
      # A single output head produces all parameters (e.g. t and s) at once
      self.dims = nn.Linear(hidden_dim, out_dim * num_params)

    def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
      x = F.leaky_relu(self.input(x))
      for h in self.hidden:
        x = F.leaky_relu(h(x))

      # Reshape to (batch, out_dim, num_params) and split into one
      # (batch, out_dim) tensor per parameter
      batch_params = self.dims(x).reshape(x.size(0), self.out_dim, -1)
      params = batch_params.chunk(self.num_params, dim=-1)
      return [p.squeeze(-1) for p in params]
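
To tie everything together, here is a sketch of how the classes above could be assembled into a complete model. The dimensionality, the number of flows and the hyperparameters are arbitrary choices of mine, and the standard normal latent comes from torch.distributions:

# Hypothetical end-to-end assembly of the model defined above.
from torch.distributions import MultivariateNormal

data_dim = 2              # arbitrary example dimensionality
half_dim = data_dim // 2
num_flows = 4

flows = [
  AffineCouplingLayer(
    theta=Conditioner(in_dim=half_dim, out_dim=half_dim,
                      num_hidden=2, hidden_dim=64, num_params=2),
    split=split_half)     # the half split sketched earlier
  for _ in range(num_flows)
]

latent = MultivariateNormal(torch.zeros(data_dim), torch.eye(data_dim))
model = NormalizingFlow(latent, flows)

Note that a separate conditioner is instantiated for every flow, so each coupling layer learns its own parameters.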

And with that we have specified the entire model! Compared to a network designed for a regression or classification problem there are quite a few design choices involved in a normalizing flows model. Let’s take a moment to recap the steps we went through.

  • We want to learn the relationship between \(p(x)\) and \(p(z)\)
  • This means learning the bijection \(z = f(x)\)
  • \(f\) is decomposed into \( f = f_n \circ f_{n-1} \circ \dots \circ f_1 \)
  • \(f_i\) is modeled using coupling layers with coupling function \(\tau_i\) and conditioner \(\theta_i\)
  • We let \(\tau_i\) be an affine transformation \(\tau_i(x) = \exp(s_i)x + t_i\)
  • We learn the parameters \(\left (s_i, t_i \right ) = \theta_i(x_{\leq d}) \) with a neural network

Since everything is differentiable by design and we have an exact expression for \(p(x)\), we can train this model by maximum likelihood using SGD just like any other architecture, as sketched below.
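
A minimal sketch of such a training loop (my own, assuming the model assembled above and a stand-in tensor of observations; Adam is used here in place of plain SGD) could look like this:

# Hypothetical training loop: minimize the negative log-likelihood.
from torch.utils.data import DataLoader, TensorDataset

data = torch.randn(1024, data_dim)    # stand-in for real observations
loader = DataLoader(TensorDataset(data), batch_size=128, shuffle=True)
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

for epoch in range(100):
  for (x,) in loader:
    optimizer.zero_grad()
    loss = -model.log_prob(x).mean()  # negative log p(x)
    loss.backward()
    optimizer.step()

And with that, it is time to wrap up.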

Ending notes

In this article we have covered the core theory of normalizing flows, and seen how to implement a particular flavour of them using affine coupling layers. If you want to dig deeper into normalizing flows for probabilistic modeling I recommend this review paper, which does a good job covering the field. While this article covered techniques that are no longer state-of-the-art, it should give you enough background to dig into more advanced approaches such as Flow++, FFJORD or Residual Flows.

In the follow-up article we will put this implementation to use and fit it to some actual data.

Thank you for reading, I hope you found it interesting. If you have any comments or questions, don’t hesitate to reach out to me.

© Sebastian Callh 2020