May 9, 2021

VAEs as a framework for probabilistic inference

VAEs frequently get compared to GANs, and then dismissed since “GANs produce better samples”. While this might be true for specific VAEs, I think this sells VAEs short. Do I claim that VAEs generate better samples of imaginary celebrities? No (but they are also pretty good). What I mean is that they are qualitatively different and much more general than people give them credit for. In this article we are going to consider VAEs as a family of latent variable models and discover that they offer a unified black-box inference framework for probabilistic modelling. The probabilistic programming language Pyro (which sits on top of PyTorch) pretty much embodies this idea, and we will use it to implement VAEs for several different machine learning tasks. You can find the full code here.

VAEs as latent variable models

One common motivation for VAEs you might stumble upon is image generation. We want to generate novel digits, celebrity faces, etc., but if we were to use a deterministic autoencoder, we would not be able to construct a representative latent space to sample from. To remedy this, we introduce a Gaussian latent variable, add a regularization term (the “KL-term”) that keeps the encoder’s distribution close to a standard Gaussian prior, and voilà, we can now produce good samples. I would argue that this derivation is a bit like saying “Being Bayesian is just adding regularization” and completely misses the point of VAEs.

Instead of looking at VAEs as some ad-hoc regularized autoencoder, we should think of them as probabilistic latent variable models. By looking at VAEs through a probabilistic lens from the ground up, we will discover that they are not a specific hourglass-shaped model for image reconstruction (like all those YouTube thumbnails might lead you to believe), but embody a rich class of models for probabilistic inference. To show this, let us first consider an abstract latent variable model \(p(x, z) = p_\theta(x \vert z) p_\theta(z)\), described graphically in Figure 1.

Figure 1: Plate diagram of a latent variable model in all its generality. While \(p(z)\) could depend on \(\theta\), it is commonly modelled as a standard Gaussian in VAEs, omitting the arrow \(\theta \to z\).

This model corresponds to the generative story in which we first sample \(z_i \sim p_\theta(z)\), then sample \(x_i \sim p_\theta(x \vert z_i)\), where \(i = 1,2,\dots,N\) and \(\theta\) are the model parameters. So far this is just an abstract latent variable model; we could assign any density to \(x\) and \(z\), and they could be related in every conceivable way. It could for instance be probabilistic PCA, if we assume Gaussian distributions and a linear relationship. Of course, PCA is a very limited model class, and we might desire a more flexible model. But, as is often the case in probabilistic inference, we quickly run into problems inferring the posterior. While we could use sampling-based methods to approximate it, these scale poorly to high-dimensional data. Another common approximation is mean-field variational inference, but that requires optimizing a separate set of variational parameters for every data point to infer \(z\), which is undesirable for large datasets. This is where the VAE swoops in to save the day: it allows us to use very complex models for the data generating process and still approximate the posterior efficiently.
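To make the generative story concrete, here is a small NumPy sketch of ancestral sampling for probabilistic PCA, the linear-Gaussian special case mentioned above: draw \(z_i\) from the prior, then draw \(x_i\) given \(z_i\). The loading matrix `W`, the offset `mu`, the noise level `sigma` and all dimensions are arbitrary illustrative values, not taken from the article. Because this particular model is linear and Gaussian, its marginal covariance is known in closed form, \(W W^\top + \sigma^2 I\), so the sample covariance gives a quick sanity check.

```python
import numpy as np


def sample_ppca(n, W, mu, sigma, rng):
    """Ancestral sampling from probabilistic PCA:
    z_i ~ N(0, I), then x_i ~ N(W z_i + mu, sigma^2 I)."""
    d_x, d_z = W.shape
    z = rng.standard_normal((n, d_z))              # sample from the prior p(z)
    noise = sigma * rng.standard_normal((n, d_x))
    x = z @ W.T + mu + noise                       # sample from p(x | z)
    return z, x


# Illustrative parameters (assumptions, not from the article):
# 3-dimensional data generated from a 2-dimensional latent variable.
rng = np.random.default_rng(0)
W = np.array([[2.0, 0.0], [1.0, 1.0], [0.0, 0.5]])
mu = np.array([1.0, -1.0, 0.0])
sigma = 0.1

z, x = sample_ppca(100_000, W, mu, sigma, rng)

empirical_cov = np.cov(x, rowvar=False)
model_cov = W @ W.T + sigma**2 * np.eye(3)         # analytic marginal covariance
print(np.round(empirical_cov, 2))
print(np.round(model_cov, 2))
```

Swapping the linear map `z @ W.T` for a neural network keeps sampling just as easy, but the posterior \(p_\theta(z \vert x)\) then loses its closed form, which is exactly the problem described above.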

VAE anatomy

The VAE has two main components: the decoder and the encoder. The decoder describes the generative process of our data, and is typically implemented as a deep neural network which parameterizes the conditional density \(p_\theta(x \vert z)\). Introducing non-linear behaviour between \(x\) and \(z\) makes it possible to model very complex data, but also makes the posterior \(p_\theta(z \vert x)\) intractable. How does the VAE solve this?

By introducing an additional neural network: the encoder (or recognition model), which parameterizes \(q_\phi(z \vert x)\). This allows learning an approximate posterior without resorting to slow sample-based methods or optimizing over \(N\) latent variables. While the authors of the original paper did not use this term, using a neural network to learn a distribution in this way has become known as amortized inference. This is a very powerful idea that lets us leverage all the progress made on neural networks in recent years for probabilistic inference, and lets us get away with using really complicated models for the data. Putting this together, we end up with the slightly modified graphical model in Figure 2.
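The payoff of amortization shows up directly in the parameter count. In the PyTorch sketch below (sizes are assumptions: 784-dimensional inputs, a 2-dimensional latent, 256 hidden units), mean-field VI keeps its own mean and standard deviation for every data point, so its parameters grow with \(N\), whereas an amortized encoder has a fixed number of weights and produces the parameters of \(q_\phi(z \vert x)\) for any input, including unseen ones, with a single forward pass. `AmortizedEncoder` and `mean_field_param_count` are hypothetical names used only for illustration.

```python
import torch
import torch.nn as nn

X_DIM, Z_DIM, HIDDEN = 784, 2, 256   # illustrative sizes


class AmortizedEncoder(nn.Module):
    """Hypothetical encoder: maps each x to the mean and std of q(z | x)."""

    def __init__(self):
        super().__init__()
        self.body = nn.Sequential(nn.Linear(X_DIM, HIDDEN), nn.Softplus())
        self.loc = nn.Linear(HIDDEN, Z_DIM)
        self.scale = nn.Sequential(nn.Linear(HIDDEN, Z_DIM), nn.Softplus())

    def forward(self, x):
        h = self.body(x)
        return self.loc(h), self.scale(h)


def mean_field_param_count(n):
    # one mean and one standard deviation per latent dimension, per data point
    return 2 * Z_DIM * n


encoder = AmortizedEncoder()
amortized_count = sum(p.numel() for p in encoder.parameters())

for n in (1_000, 60_000, 1_000_000):
    print(f"N={n}: mean-field {mean_field_param_count(n)}, amortized {amortized_count}")

# The same encoder handles any batch, including data it has never seen.
loc, scale = encoder(torch.rand(5, X_DIM))
```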

Figure 2: Plate diagram of a VAE as presented in the original paper. The solid lines denote the generative process and the dotted lines denote the variational posterior \(q_\phi(z \vert x)\).

Note that, as a latent variable model, the VAE makes basically no assumptions: we have only assumed a model class rich enough for the posterior to be intractable (which is more like an absence of assumptions). It does, however, make an assumption about how to perform inference, namely that inference is amortized. This is hardly a constraint, but it is an interesting observation since probabilistic modelling often separates modelling and inference into two distinct problems. The VAE breaks from that convention by absorbing the problem of inference into the model itself.

This is actually all there is to a VAE: a latent variable model fitted using amortized inference. Hence it is much more of a modelling framework than a concrete model. To drive this point home, consider the illustrated “VAE anatomy” in Figure 3. We can plug whatever transformation our problem and data require into each box. If you are familiar with generalized linear models (GLMs), this view might ring a bell. We are basically dealing with GLMs, except that the linear model is replaced with a neural network, and the “projections” are linear projections followed by inverse link functions. If that didn’t make any sense to you, don’t worry. We are going to reconstruct an example from the original paper using this framework to make things more concrete.
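The GLM analogy can be made concrete with a small NumPy sketch. A GLM computes the mean of the observation model by applying an inverse link to a linear predictor; a VAE-style decoder \(P_d(f_d(z))\) does the same, except that a network body \(f_d\) sits in front of the linear layer. With an identity body the decoder collapses back to a GLM (logistic regression for binary data). The helper names, the random weights and the tanh body are illustrative assumptions, and softplus is used as the inverse link for positive rates to mirror the projections in this article (a classic Poisson GLM would use exp).

```python
import numpy as np


def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))


def softplus(a):
    return np.log1p(np.exp(a))


def glm_mean(z, W, b, inv_link):
    """GLM: inverse link applied to a linear predictor."""
    return inv_link(z @ W + b)


def decoder_mean(z, body, W, b, inv_link):
    """Decoder P_d(f_d(z)): network body, then linear layer, then inverse link."""
    return inv_link(body(z) @ W + b)


rng = np.random.default_rng(1)
z = rng.standard_normal((4, 2))
W, b = rng.standard_normal((2, 3)), rng.standard_normal(3)
V = rng.standard_normal((2, 2))          # weights of a fixed, illustrative body


def identity_body(h):
    return h


def tanh_body(h):
    return np.tanh(h @ V)


# With an identity body the decoder is exactly logistic regression.
p_glm = glm_mean(z, W, b, sigmoid)
p_dec = decoder_mean(z, identity_body, W, b, sigmoid)

# A non-linear body still yields valid Bernoulli probabilities or positive rates.
p_nonlin = decoder_mean(z, tanh_body, W, b, sigmoid)
rates = decoder_mean(z, tanh_body, W, b, softplus)
```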

Figure 3: The anatomy of a VAE with a single latent variable. \(f_e\) is implemented depending on what kind of data we deal with (e.g. a convolutional network for images, an RNN for temporal data), and then \(P_e\) projects onto the domain of the parameters of \(q_\phi(z \vert x)\). The same process is applied when decoding: \(f_d\) is implemented depending on the structure of \(z\) (typically a fully connected network), and then \(P_d\) projects onto the domain of the parameters of the observation model (e.g. \([0, 1]\) for binary data, \(\mathbb{R}_+\) for rates, etc.). Symbols in lower-right corners denote parametrization.

Before looking at concrete examples, I want to mention that we are going to gloss over all the mathematical details of how to take gradients and fit a VAE (which, ironically, is probably the paper’s largest contribution). Fortunately, the details are clearly described in the original paper, and there are countless blog posts out there explaining how to derive the ELBO and how the reparametrization trick works, so I’m sure another explanation wouldn’t add much. Instead, let us see how to apply the idea of the VAE anatomy to reproduce the MNIST example (we will do more interesting things than regular MNIST, I promise).
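Still, for anyone curious about what happens under the hood, the sketch below estimates the ELBO \(\mathbb{E}_{q_\phi(z \vert x)}[\log p_\theta(x \vert z)] - \mathrm{KL}(q_\phi(z \vert x) \,\Vert\, p(z))\) for a Gaussian encoder and a Bernoulli decoder using plain `torch.distributions`. The reparametrization trick writes \(z = \mu + \sigma \odot \epsilon\) with \(\epsilon \sim \mathcal{N}(0, I)\), so gradients can flow back into \(\mu\) and \(\sigma\). The encoder outputs, decoder and data are random stand-ins (assumptions for illustration). Pyro's `Trace_ELBO` estimates the KL part by Monte Carlo rather than in closed form, so this is a closely related estimator, not exactly what Pyro computes.

```python
import torch
from torch.distributions import Bernoulli, Normal, kl_divergence

torch.manual_seed(0)
N, X_DIM, Z_DIM = 8, 784, 2

# Random stand-ins (assumptions): binary data and encoder outputs.
x = torch.randint(0, 2, (N, X_DIM)).float()
loc = torch.randn(N, Z_DIM, requires_grad=True)
raw_scale = torch.randn(N, Z_DIM, requires_grad=True)
scale = torch.nn.functional.softplus(raw_scale)          # P_e: project onto R_+
decoder = torch.nn.Sequential(torch.nn.Linear(Z_DIM, X_DIM), torch.nn.Sigmoid())

q_z = Normal(loc, scale)                                  # q(z | x)
p_z = Normal(torch.zeros_like(loc), torch.ones_like(scale))  # prior p(z)

# Reparametrization trick: z is a deterministic function of (loc, scale) and noise,
# equivalent in distribution to q_z.rsample().
eps = torch.randn_like(loc)
z = loc + scale * eps

log_lik = Bernoulli(probs=decoder(z)).log_prob(x).sum(-1)  # single-sample log p(x | z)
kl = kl_divergence(q_z, p_z).sum(-1)                       # closed-form Gaussian KL
elbo = (log_lik - kl).mean()

(-elbo).backward()    # gradients reach the encoder outputs through z
```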

Reproducing the MNIST experiment

In this example, we assume \(z \sim \mathcal{N}(\mu, \Sigma)\) and \(x \sim \text{Bernoulli}(\pi)\), and implement both \(f_e\) and \(f_d\) as fully connected networks. These are easily constructed in PyTorch.

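# Imports used by the code snippets throughout this article
from math import prod
from typing import Dict, Optional, Tuple

import pyro
import torch
import torch.nn as nn
import torch.nn.functional as F
from pyro.distributions import Bernoulli, Categorical, Gamma, Normal
from torch import Tensor


class SimpleMLP(nn.Sequential):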
    """Simple network to be used in encoder/decoder."""

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        hidden_dim: int,
    ):
        super().__init__(
            nn.Linear(in_dim, hidden_dim),
            nn.Softplus(),
            nn.Linear(hidden_dim, out_dim),
        )

We also need to construct \(P_e\) and \(P_d\). These map hidden representations to \(\mu, \Sigma\) and \(\pi\), respectively. Since we do not assume the dimensionality of \(h_{e}\) or \(h_{d}\) (the outputs of \(f_e\) and \(f_d\)) to be the same as the number of distribution parameters, \(P_e\) and \(P_d\) both apply a linear transformation to get the dimensionality right before projecting onto \(\mathbb{R}\) for \(\mu\), \(\mathbb{R}_+\) for \(\Sigma\) and \([0, 1]\) for \(\pi\). These projections are implemented in the two modules below.

class LocScale(nn.Module):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__()
        self.loc = nn.Linear(in_dim, out_dim)
        self.scale = nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.Softplus(),
        )

    def forward(self, x: Tensor) -> Tuple[Tensor, Tensor]:
        return self.loc(x), self.scale(x)


class Binary(nn.Sequential):
    def __init__(self, in_dim: int, out_dim: int):
        super().__init__(
            nn.Linear(in_dim, out_dim),
            nn.Sigmoid(),
        )

Finally we tie everything together in a VAE module.

class VAE(nn.Module):
    """Standard variational autoencoder model
    for binary image data with a
    Gaussian latent variable."""

    def __init__(
        self,
        x_shape: Tuple[int, int, int],
        z_dim: int,
        hidden_dim: int,
        num_params: int,
    ):
        super().__init__()
        self.z_dim = z_dim
        self.encode = nn.Sequential(
            nn.Flatten(),
            SimpleMLP(
                in_dim=prod(x_shape[-2:]),
                out_dim=hidden_dim,
                hidden_dim=hidden_dim,
            ),
            LocScale(
                in_dim=hidden_dim,
                out_dim=z_dim,
            ),
        )

        self.decode = nn.Sequential(
            SimpleMLP(
                in_dim=z_dim,
                out_dim=hidden_dim,
                hidden_dim=hidden_dim,
            ),
            Binary(
                in_dim=hidden_dim,
                out_dim=num_params,
            ),
        )

You can probably see how this structure mirrors that of Figure 3: the encoder is an MLP followed by projecting onto \(\mu\) and \(\Sigma\), and the decoder is an MLP followed by projecting onto \(\pi\). Having constructed all the necessary transformations, we get to the heart of the model: the random variables. In Pyro we specify the model in a model function and the variational distribution in a guide function.

    # ... continuation of the VAE class
    def model(self, x: Tensor) -> None:
        """Generative model p(x|z)p(z).
        Describes the generative story of our data."""
        pyro.module("vae", self)
        N = x.shape[0]
        with pyro.plate("N", N):

            # sample latent variable z
            z_dim = N, self.z_dim
            p_z = Normal(x.new_zeros(z_dim), 1).to_event(1)
            z = pyro.sample("z", p_z)

            # decode and sample observation
            # validate_args=False lets observations take any value in [0, 1]
            # (e.g. grayscale pixels), not only {0, 1}
            pi = self.decode(z)
            p_x = Bernoulli(pi, validate_args=False).to_event(1)
            pyro.sample("x", p_x, obs=x.view(N, -1))


    def guide(self, x: Tensor) -> None:
        """Variational distribution q(z|x).
        Used to infer the latent variables in our model.
        For a VAE this is just a neural network."""
        with pyro.plate("N", x.shape[0]):
            loc, scale = self.encode(x)
            p_z = Normal(loc, scale).to_event(1)
            pyro.sample("z", p_z)

And that concludes constructing our first VAE! It can now be trained by maximizing the ELBO (a lower bound on the log-likelihood) and used to produce samples like those in Figure 4. For brevity we skip the details of training, but thanks to Pyro’s abstractions it is no more complicated than training a regular neural network. Please see the full source code for details. Having seen how to implement a basic VAE in Pyro, let us now see how it can be generalized to other interesting problem types.

Figure 4: Samples from a VAE trained on MNIST as described in the original VAE paper.

VAEs for semi-supervised learning

VAEs are often thought of as unsupervised models, but once we think of them as probabilistic models, we can see that they generalize to supervised and semi-supervised settings without much trouble. Figure 5 illustrates this version of the VAE in plate notation. We introduce a partially observed latent variable \(y\) for the labels, and use the same amortized inference machinery to infer both \(z\) and \(y\).

Figure 5: Plate diagram of a semi-supervised VAE. Dashed outlines indicate partially observed variables. The parameters \(\phi\) and \(\theta\) are omitted to reduce clutter.

The changes required to turn our VAE into a semi-supervised VAE are straightforward and only touch the Pyro parts. (Strictly speaking, the training code changes too, but those are mainly implementation details; see the paper for details and the full code for an implementation.) We are going to:

  • Add the new random variable \(y\) to the model
  • Add an inference network for \(y\)
  • Update the decoder network to take both \(z\) and \(y\)
  • In the full code, I also changed \(f_e\) to a CNN, just because
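The last two changes boil down to feeding the label into the networks. One common way to build a decoder that takes both \(z\) and \(y\) is to one-hot encode \(y\) and concatenate it with \(z\) before the network body. This is only an assumption about how `decode_x` might look (the full code may do it differently); the class name `ConditionalDecoder` and all sizes are hypothetical.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ConditionalDecoder(nn.Module):
    """Hypothetical decode_x: maps (z, y) to Bernoulli pixel probabilities."""

    def __init__(self, z_dim: int, num_classes: int, hidden_dim: int, x_dim: int):
        super().__init__()
        self.num_classes = num_classes
        self.net = nn.Sequential(
            nn.Linear(z_dim + num_classes, hidden_dim),
            nn.Softplus(),
            nn.Linear(hidden_dim, x_dim),
            nn.Sigmoid(),                  # P_d: project onto [0, 1]
        )

    def forward(self, z: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # y holds integer class labels of shape (N,); one-hot encode and append to z
        y_onehot = F.one_hot(y, self.num_classes).to(z.dtype)
        return self.net(torch.cat([z, y_onehot], dim=-1))


decode_x = ConditionalDecoder(z_dim=2, num_classes=10, hidden_dim=64, x_dim=784)
z = torch.randn(3, 2)
y = torch.tensor([0, 4, 9])
pi = decode_x(z, y)
```

Because the label is an explicit input, fixing \(y\) and varying \(z\) (or the other way around) is what makes conditional samples possible.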

For brevity, we only look at the changes to the model here.

def model(self, x: Tensor, y: Optional[Tensor] = None) -> None:
    """Model p(x|z,y)p(y)p(z)"""

    pyro.module("ssvae", self)
    N = x.size(0)
    with pyro.plate("N", N):

        # Sample latent variable
        z_size = torch.Size((N, self.z_dim))
        z_loc, z_scale = x.new_zeros(z_size), x.new_ones(z_size)
        p_z = Normal(z_loc, z_scale).to_event(1)
        z = pyro.sample("z", p_z)

        # Sample class label (observed when y is given, latent otherwise)
        y_size = torch.Size((N, self.num_classes))
        alpha = x.new_ones(y_size) / self.num_classes
        p_y = Categorical(alpha)
        y_sample = pyro.sample("y", p_y, obs=y)

        # Sample the observation
        # validate_args=False lets observations take any value in [0, 1]
        pi = self.decode_x(z, y_sample)
        p_x = Bernoulli(pi, validate_args=False).to_event(1)
        pyro.sample("x", p_x, obs=x.view(N, -1))


def guide(self, x: Tensor, y: Optional[Tensor] = None) -> None:
    """Variational distribution q(z|x,y)q(y|x)"""

    with pyro.plate("N", x.shape[0]):
        # Only sample y from the classifier q(y|x) when the label is missing
        if y is None:
            pi = self.encode_y(x)
            p_y = Categorical(pi)
            y = pyro.sample("y", p_y)

        loc, scale = self.encode_z(x, y)
        p_z = Normal(loc, scale).to_event(1)
        pyro.sample("z", p_z)

This requires more parameters in \(\phi\) and \(\theta\), but we can share all the parameters of the network body \(f_d\), so it is not as bad as one might first think. And being able to deal with problems in a semi-supervised fashion is really powerful. This model achieves \(0.96\) accuracy on MNIST with labels for less than \(10\)% of the data, and allows drawing conditional posterior samples, as seen in Figure 6. While MNIST is indeed a solved problem, the interesting takeaway here is that we can transition from an unlabeled regime to a semi-supervised regime with only a few adjustments to our model.
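Since the guide contains a classifier \(q_\phi(y \vert x)\), one natural way to use the trained model as a predictor is to take the most probable class under `encode_y`, and accuracy is then the fraction of matching labels. The sketch below assumes `encode_y` outputs class probabilities; a hypothetical linear softmax classifier and random data stand in for the trained network and MNIST.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-in for encode_y: outputs class probabilities q(y | x).
encode_y = nn.Sequential(nn.Flatten(), nn.Linear(784, 10), nn.Softmax(dim=-1))


def accuracy(encode_y: nn.Module, x: torch.Tensor, y: torch.Tensor) -> float:
    """Fraction of examples whose most probable class under q(y|x) matches the label."""
    with torch.no_grad():
        y_pred = encode_y(x).argmax(dim=-1)
    return (y_pred == y).float().mean().item()


x_test = torch.rand(100, 1, 28, 28)        # stand-in images
y_test = torch.randint(0, 10, (100,))      # stand-in labels
acc = accuracy(encode_y, x_test, y_test)
```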

VAEs for multiple observed variables

So far we have seen VAEs applied to image data, where we treat all the pixels as a single multivariate random variable. This makes sense since images are high-dimensional, complicated things, and it’s difficult to assign additional structure to them. But remember that the structure of our data only affects how we fill in the boxes of the VAE anatomy; the VAE framework as a whole can deal with any kind of data. Before we wrap up, let us showcase this with some tabular data where each column, or feature, has a particular meaning. What if we want to treat features of our dataset as separate random variables? Well, nothing is stopping us! Consider a dataset of points \(x = (x_1, x_2, x_3) \in \mathbb{R}^3\) generated by

\begin{equation} \begin{split} x_1 & \sim \Gamma(k_0, \theta_0) \\ x_2 & \sim \mathcal{N}(\mu_0, \sigma_0) \\ x_3 & \sim \mathcal{N}(x_2, x_1). \\ \end{split} \end{equation}
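This dataset is easy to simulate. The NumPy sketch below draws from it; the values of \(k_0, \theta_0, \mu_0, \sigma_0\) are arbitrary choices (the article does not state the ones used), \(\Gamma(k, \theta)\) is read as shape and scale, and the second argument of \(\mathcal{N}\) is read as a standard deviation. Under these choices \(\mathrm{Var}(x_3) = \mathrm{Var}(x_2) + \mathbb{E}[x_1^2] = 1 + 1.5\), so the correlation between \(x_2\) and \(x_3\) should come out near \(1/\sqrt{2.5} \approx 0.63\).

```python
import numpy as np


def sample_dataset(n, k0=2.0, theta0=0.5, mu0=0.0, sigma0=1.0, seed=0):
    """Draw n rows (x_1, x_2, x_3) from the generative process above.
    Parameter defaults are illustrative assumptions."""
    rng = np.random.default_rng(seed)
    x1 = rng.gamma(shape=k0, scale=theta0, size=n)   # x_1 ~ Gamma(k_0, theta_0)
    x2 = rng.normal(loc=mu0, scale=sigma0, size=n)   # x_2 ~ N(mu_0, sigma_0)
    x3 = rng.normal(loc=x2, scale=x1)                # x_3 ~ N(x_2, x_1), elementwise
    return np.stack([x1, x2, x3], axis=1)


X = sample_dataset(50_000)
print(X.shape)
print(np.corrcoef(X, rowvar=False).round(2))
```

The correlation between \(x_2\) and \(x_3\), and the spread of \(x_3\) that varies with \(x_1\), are the kind of structure the latent variable has to capture.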

While not the most complicated dataset, there is meaningful structure in it, which needs to be taken into account to fit the data. We trust the latent variable \(z\) to capture this structure, and introduce additional observed random variables, illustrated in Figure 6.

Figure 6: A VAE is able to learn densities over multiple random variables. In this case we have three observed variables, each with a different density. The parameters \(\phi\) and \(\theta\) are omitted to reduce clutter.

We can implement a VAE for this dataset with minor changes to our observation model and to the decoder. Instead of sampling from a single multivariate distribution, we simply sample from the individual univariate distributions.

def model(self, x: Tensor) -> None:
    """Generative model p(x|z)p(z)"""
    pyro.module("vae", self)
    N = x.shape[0]
    with pyro.plate("N", N):
        # sample latent variable z
        z_dim = N, self.z_dim
        z_mu, z_sigma = x.new_zeros(z_dim), x.new_ones(z_dim)
        z = pyro.sample("z", Normal(z_mu, z_sigma).to_event(1))

        # decode and sample each observed variable from its own density,
        # using the parameter names returned by Params
        ps = self.decode(z)
        # Pyro's Gamma takes (concentration, rate), so the scale theta becomes rate 1/theta
        pyro.sample("x_1", Gamma(ps["k"], 1.0 / ps["theta"]), obs=x[:, 0])
        pyro.sample("x_2", Normal(ps["loc"], ps["scale"]), obs=x[:, 1])
        pyro.sample("x_3", Normal(ps["mu"], ps["sigma"]), obs=x[:, 2])
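Splitting the observation into three `pyro.sample` sites inside the plate means the likelihood factorizes per row as \(\log p_\theta(x \vert z) = \sum_{j=1}^{3} \log p_\theta(x_j \vert z)\), each term with its own family. The sketch below mirrors this with `torch.distributions`; the parameter values stand in for decoder outputs and the data rows are made up for illustration.

```python
import torch
from torch.distributions import Gamma, Normal

# Hypothetical decoder outputs for a batch of 4 rows (one value per row and parameter).
N = 4
k, rate = torch.full((N,), 2.0), torch.full((N,), 2.0)
loc, scale = torch.zeros(N), torch.ones(N)
mu, sigma = torch.zeros(N), torch.full((N,), 1.5)

# Made-up observations; column 0 must be positive for the Gamma.
x = torch.tensor([
    [0.5, 0.1, -0.3],
    [1.2, -0.7, 0.4],
    [0.9, 1.1, 2.0],
    [2.3, 0.0, -1.5],
])

# One univariate distribution per column, mirroring the three pyro.sample sites.
log_p_cols = torch.stack([
    Gamma(k, rate).log_prob(x[:, 0]),
    Normal(loc, scale).log_prob(x[:, 1]),
    Normal(mu, sigma).log_prob(x[:, 2]),
], dim=-1)

log_p_rows = log_p_cols.sum(-1)   # log p(x_i | z_i) for each row
total = log_p_rows.sum()          # contribution of the plate to log p(x | z)
```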

Of course, we also need to update \(P_d\).

class VAE(nn.Module):
    # other initializations omitted for brevity
    def __init__(...):
        super().__init__()
        self.decode = nn.Sequential(
            # f_d must output 6 values: one per parameter that Params reads
            SimpleMLP(...),
            Params(),
        )

class Params(nn.Module):
    def forward(self, x: Tensor) -> Dict[str, Tensor]:
        return {
            "k": F.softplus(x[:, 0]),
            "theta": F.softplus(x[:, 1]),
            "loc": x[:, 2],
            "scale": F.softplus(x[:, 3]),
            "mu": x[:, 4],
            "sigma": F.softplus(x[:, 5]),
        }
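One detail worth checking when wiring `Params` to the model: the data equation uses a shape-scale \(\Gamma(k, \theta)\), whereas `torch.distributions.Gamma` (which Pyro's `Gamma` wraps) is parameterized by `concentration` and `rate`. A network can learn either parameterization, but if `theta` is meant as a scale it has to be passed as `rate = 1 / theta`. The sketch below checks that conversion against the known mean \(k\theta\) and variance \(k\theta^2\), and against NumPy's shape-scale sampler; the values of `k` and `theta` are arbitrary.

```python
import numpy as np
import torch
from torch.distributions import Gamma

k, theta = 3.0, 0.5   # illustrative shape and scale

# Shape-scale Gamma(k, theta) expressed in PyTorch's (concentration, rate) form.
g = Gamma(concentration=torch.tensor(k), rate=torch.tensor(1.0 / theta))

print(g.mean.item(), k * theta)          # mean of Gamma(k, theta) is k * theta
print(g.variance.item(), k * theta**2)   # variance is k * theta^2

# Cross-check against NumPy, which samples with shape and scale directly.
rng = np.random.default_rng(0)
np_samples = rng.gamma(shape=k, scale=theta, size=200_000)
print(np_samples.mean())
```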

Observe how we still only use a single network for \(f_d\), and how the rest of the VAE machinery remains intact; we do not need to touch the encoder or the inference algorithm. The fitted model successfully captures the true density, as illustrated in Figure 7.

Figure 7: A VAE is able to fit to data with multiple observed random variables, related through latent structure.

Closing words

I hope this article gave you some new appreciation of VAEs. They are much better viewed as a framework for fitting latent variable models with intractable posteriors than as regularized autoencoders. VAEs are fit through amortized inference, which lets us use the same procedures as when fitting a neural network, with mini-batched data and GPU-accelerated optimization. Thanks to tools like Pyro, VAEs are very easy to implement and tailor to your specific task, be it unsupervised, semi-supervised, or fully supervised. If you’d like to dig a bit deeper, the authors of the original VAE paper also wrote an introductory paper on the topic, which is a good place to start.

Thank you for reading, I hope you found it interesting!

© Sebastian Callh 2020