May 14, 2020

Bayesian inference with Stochastic Gradient Langevin Dynamics

Modern machine learning algorithms can scale to enormous datasets and reach superhuman accuracy on specific tasks. Yet, they are largely incapable of answering “I don’t know” when queried with new data. Taking a Bayesian approach to learning lets models be uncertain about their predictions, but classical Bayesian methods do not scale to modern settings. In this post we are going to use Julia to explore Stochastic Gradient Langevin Dynamics (SGLD), an algorithm which makes it possible to apply Bayesian learning to deep learning models and still train them on a GPU with mini-batched data.

Bayesian learning

A lot of digital ink has been spilled arguing for Bayesian learning, particularly in domains where knowing model certainty is important, such as medicine and autonomous driving. Unfortunately, Bayesian learning comes with some challenges; in particular, it is difficult to scale to the size of modern deep learning models. One reason for this is that Bayesian inference relies heavily on sampling algorithms, which suffer under the curse of dimensionality. Additionally, sampling from the posterior requires evaluating the likelihood function, which requires a whole pass over the data. This is computationally expensive if the dataset is large.
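To see why the likelihood is the bottleneck, note that for \(N\) i.i.d. observations it factorizes, so its logarithm decomposes into a sum over the whole dataset,

\[ \log p(x \vert \theta) = \sum_{i=1}^N \log p(x_i \vert \theta), \]

which means a single evaluation costs a full pass over all \(N\) data points.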

Consequently, Bayesian deep learning is an active area of research, and several different approaches have been proposed. Some include treating dropout as approximate Bayesian inference, turning the inference problem into an optimization problem through variational inference, and treating the optimization trajectory as an MCMC chain. They all make different trade-offs and are all worth reading about, but in this blog post we are going to focus on the third approach.

Gradient descent

Let’s start off by looking at gradient descent for maximum a posteriori (MAP) estimation. We will use it to build up to SGLD. Recall that when training a model using standard gradient descent the model parameters \(\theta\) are initialized randomly and updated by iterating

\[ \theta_{t+1} = \theta_t + \Delta \theta_t \] where \[ \Delta \theta_t = \eta \left( \nabla \log p(\theta_t) + \nabla \log p(x \vert \theta_t) \right ) \]

until convergence. Here \(\eta\) denotes the learning rate, which is the only hyperparameter, \( p(\theta_t) \) denotes our prior over \(\theta\) (commonly referred to as the regularization term) and \( p(x \vert \theta_t) \) the likelihood (its negative logarithm is commonly used as the loss function). Gradient descent is appealing not only because of its simplicity but also because of its scalability. Modern machine learning models are commonly trained on huge datasets which do not fit into memory, but this poses no problem since the gradient can be estimated using mini-batches. Using mini-batches of size \(n\) drawn from a dataset of size \(N\) gives us the famous SGD algorithm with the parameter update rule

\[ \Delta \theta_t = \eta \left( \nabla \log p(\theta_t) + \frac{N}{n} \sum_{i=1}^n \nabla \log p(x_i \vert \theta_t) \right ). \]

The factor \(N/n\) rescales the mini-batch sum into an unbiased estimate of the full-data log-likelihood gradient.
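To make the rescaling concrete, here is a minimal self-contained sketch (my own illustration, not code from this post) that compares the rescaled mini-batch gradient to the exact full-data score for a Gaussian mean model, where \( \nabla_\mu \log p(x \vert \mu) = \sum_i (x_i - \mu) \):

    using Random
    Random.seed!(0)

    # Sketch: scaling a mini-batch gradient sum by N/n gives an unbiased
    # estimate of the full-data score Σᵢ (xᵢ - μ).
    N, n, μ = 10_000, 100, 0.5
    xs = randn(N)                             # toy dataset
    full_grad = sum(xs .- μ)                  # exact full-data gradient
    batch = xs[rand(1:N, n)]                  # one random mini-batch
    est_grad = (N / n) * sum(batch .- μ)      # rescaled mini-batch estimate
    println("full: $full_grad, estimate: $est_grad")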

Stochastic Gradient Langevin Dynamics

The authors of the Bayesian Learning via Stochastic Gradient Langevin Dynamics paper show that we can interpret the optimization trajectory of SGD as a Markov chain whose equilibrium distribution is the posterior over \(\theta\). This might sound intimidating, but the practical implications of this result are surprisingly simple: we train the model using regular SGD, but add some Gaussian noise to each step. We then let the injected noise and learning rate decay towards \(0\) as training time \(t\) increases. The intuitive explanation for why this works is that it lets the optimization process find a local mode, but the process never actually converges due to the injected noise. However, it is also prevented from leaving the mode due to the decayed learning rate, leaving it performing a random walk within the mode. We refer to the process before settling in a local mode as the optimization phase and the process after that as the sampling phase.

By collecting the parameters found during the sampling phase we can approximate the posterior around the mode and enjoy the benefits of Bayesian learning, while at the same time performing the optimization using mini-batched data on powerful GPUs. This lets SGLD overcome the curse of dimensionality and the need to evaluate the likelihood function over all the data (one mini-batch is enough), while still being able to acquire posterior samples. The algorithm really blurs the line between the two worlds of point-estimate optimization and Bayesian inference.

Let’s formalize this intuition. As presented in the paper (in the authors’ notation), the update rule for SGLD is given by

\[ \Delta \theta_t = \frac{\epsilon_t}{2} \left( \nabla \log p(\theta_t) + \frac{N}{n} \sum_{i=1}^n \nabla \log p(x_i \vert \theta_t) \right ) + \eta_t, \] where \(\eta_t \sim \mathcal{N}(0, \epsilon_t) \) and the step size \( \epsilon_t \) decays according to \( \epsilon_t = a(b + t)^{-\gamma} \). (Please note that \(\eta_t\) in this notation denotes the injected noise and not the learning rate.) Unsurprisingly, it looks very similar to SGD. However, we now pay the price of having three hyperparameters \(a\), \(b\) and \(\gamma\) instead of just one. Having fiddled around with these, I can tell you that model performance is very sensitive to the values we pick, so we had best choose them intelligently. Plotting \(\epsilon_t\) helps us get an intuition for how they affect it.
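The exact code behind Figure 1 is not shown here, but a minimal sketch along these lines (with assumed hyperparameter values, including the defaults used for training later) produces such curves:

    using Plots

    # Sketch: how a, b and γ shape the step size schedule ϵ_t = a(b + t)^-γ.
    ϵ(t; a, b, γ) = a * (b + t)^-γ
    t = 1:20_000
    plot(t, ϵ.(t, a = 10.0, b = 1000, γ = 0.9), label = "a=10, b=1000, γ=0.9",
         xlabel = "t", ylabel = "ϵ_t")
    plot!(t, ϵ.(t, a = 10.0, b = 100, γ = 0.9), label = "a=10, b=100, γ=0.9")
    plot!(t, ϵ.(t, a = 1.0, b = 1000, γ = 0.7), label = "a=1, b=1000, γ=0.7")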

Figure 1: The effect of hyperparameters on \(\epsilon_t\). We can think of \(a\) as the initial learning rate and \(\gamma\) as the speed at which it decays. \(\gamma\) also largely controls the asymptotic learning rate, which must be small enough to transition to the sampling phase. \(b\) decides where in the curve learning starts. For instance, we could set \(b = 100\) to avoid the super high learning rates at the start of the curve.

Use case: The Default dataset

To see how SGLD works we are going to train a logistic regression model. While the algorithm is designed to solve the problem of scaling to large models and datasets, we are mostly interested in understanding SGLD itself, which is easier with a small problem. In light of that, we will use the Default dataset from the ISLR package. It is a simple toy dataset for modeling whether a customer is going to default on their credit card debt or not. The data requires minimal pre-processing: we have to encode categorical variables as numerical values instead of string labels.

    using RDatasets, Statistics
    using Plots, Random
    Random.seed!(0)

    data = dataset("ISLR", "Default")
    todigit(x) = x == "Yes" ? 1.0 : 0.0
    data[!,:Default] = map(todigit, data[:,:Default])
    data[!,:Student] = map(todigit, data[:,:Student])
    println("Data frame contains $(size(data, 1)) observations")
    first(data, 5)
Data frame contains 10000 observations

To help convergence we will also standardize the data. Finally, 30% of the data is set aside for testing using stratified splitting.

  using MLDataUtils: shuffleobs, stratifiedobs, rescale!
  using Flux: gpu

  target = :Default
  numerics = [:Balance, :Income]
  features = [:Student, :Balance, :Income]
  train, test = shuffleobs(data) |>
      d -> stratifiedobs(first, d, p=0.7)

  for feature in numerics
      μ, σ = rescale!(train[!, feature], obsdim=1)
      rescale!(test[!, feature], μ, σ, obsdim=1)
  end

  prep_X(x) = Matrix(x)' |> gpu
  prep_y(y) = reshape(y, 1, :) |> gpu
  train_X, test_X = prep_X.((train[:, features], test[:, features]))
  train_y, test_y = prep_y.((train[:, target], test[:, target]))

While I said that one of the major selling points of SGLD is that we can do Bayesian learning using mini-batched data, we are going to use the whole dataset to get a better feeling for how the algorithm works and how it transitions from the optimization phase to the sampling phase. We are also not going to use any regularization since it is not necessary for this problem. However, both mini-batching and regularization can be implemented exactly the same way you would with SGD, as sketched below.
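For instance, mini-batches could be produced with a small helper like the following (a hypothetical sketch, not part of this post's pipeline), applying the update rules defined below once per batch gradient:

  using Random: shuffle

  # Hypothetical helper: iterate over the observations (columns of X) in
  # random mini-batches of a given size.
  function minibatches(X, y; batchsize = 128)
      idx = shuffle(1:size(X, 2))
      ((X[:, idx[i:min(i + batchsize - 1, end)]],
        y[:, idx[i:min(i + batchsize - 1, end)]])
       for i in 1:batchsize:length(idx))
  end

  # usage: for (x, y) in minibatches(train_X, train_y) ... end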

Frequentist logistic regression

Let us begin by training a “regular” logistic regression with SGD for reference. While we could code the gradients for this particular problem by hand, we will use the Flux package to do it for us. In Flux we can implement logistic regression as a dense layer with a sigmoid activation function. Let’s create a function for fitting such a model to our data, given an update rule and a number of steps (or epochs).

  using Flux, CuArrays
  using Random: seed!
  CuArrays.allowscalar(false)

  function train_logreg(;steps, update)
      seed!(1)

      paramvec(θ) = reduce(hcat, cpu(θ))
      model = Dense(length(features), 1, sigmoid) |> gpu
      θ = Flux.params(model)
      θ₀ = paramvec(θ)

      predict(x; thres = .5) = model(x) .> thres
      accuracy(x, y) = mean(cpu(predict(x)) .== cpu(y))

      loss(x, y) = mean(Flux.binarycrossentropy.(model(x), y))
      trainloss() = loss(train_X, train_y)
      testloss() = loss(test_X, test_y)

      trainlosses = [cpu(trainloss()); zeros(steps)]
      testlosses = [cpu(testloss()); zeros(steps)]
      weights = [cpu(θ₀); zeros(steps, length(θ₀))]
      for t in 1:steps

	  # ∇L holds the gradients of the training loss with respect to θ
	  ∇L = gradient(trainloss, θ)
	  foreach(θᵢ -> update(∇L, θᵢ, t), θ)

	  # Bookkeeping
	  weights[t+1, :] = cpu(paramvec(θ))
	  trainlosses[t+1] = cpu(trainloss())
	  testlosses[t+1] = cpu(testloss())
      end

      println("Final parameters are $(paramvec(θ))")
      println("Test accuracy is $(accuracy(test_X, test_y))")

      model, weights, trainlosses, testlosses
  end

With this function, training a model is straightforward. We simply call it with an implementation of the SGD update rule and a number of steps. Since we do not use any mini-batches or regularization, the update rule we have to implement is simply \( \Delta \theta_t = \eta \nabla \log p(x\vert \theta_t) \); in code we equivalently subtract the gradient of the loss, which is the negative log likelihood. (Flux comes with built-in optimizers such as SGD, but to be able to compare with SGLD we create our own.)

  sgd(∇L, θᵢ, t, η = 2) = begin
      Δθᵢ = η.*∇L[θᵢ]    # scaled gradient of the loss
      θᵢ .-= Δθᵢ         # descend the loss, i.e. ascend the log likelihood
  end

  results = train_logreg(steps = 1000, update = sgd)
  model, weights, trainlosses, testlosses = results;
Final parameters are Float32[-0.8546615 2.8704085 0.010423469 -6.0955453]
Test accuracy is 0.971

Our model converges with high accuracy. Since we did not use mini-batching, the optimization is very smooth. Had we used mini-batches, we would have introduced some noise into the training process, but the average training trajectory would look the same.

Figure 2: SGD marches straight toward the closest local minimum; initially rapidly and then more slowly as the optimization landscape flattens. The problem is an easy one, and both test and training loss go down quickly.

Bayesian logistic regression

Finally we arrive at the implementation of SGLD. Remember that we do not use mini-batching or regularization, which simplifies the SGLD update rule to \( \Delta \theta_t = \frac{\epsilon_t}{2} \nabla \log p(x \vert \theta_t) + \eta_t \) (again implemented by subtracting the loss gradient, which is equivalent since the noise is symmetric). However, we still have to pick our hyperparameters. As we will see, we also have to train for much longer for the optimization to converge.

  sgld(∇L, θᵢ, t, a = 10.0, b = 1000, γ = 0.9) = begin
      ϵ = a*(b + t)^-γ                     # step size schedule ϵ_t = a(b + t)^-γ
      η = sqrt(ϵ).*gpu(randn(size(θᵢ)))    # injected noise η_t ~ N(0, ϵ_t)
      Δθᵢ = .5ϵ*∇L[θᵢ] + η                 # gradient step plus noise
      θᵢ .-= Δθᵢ
  end

  results = train_logreg(steps = 20000, update = sgld)
  model, weights, trainlosses, testlosses = results;
Final parameters are Float32[-0.2760693 0.6358937 0.018183365 -2.781438]
Test accuracy is 0.9666666666666667

Observing the optimization process shows us how differently SGLD behaves compared to SGD. The process starts off very noisy, but as \(\epsilon_t\) decays, SGLD converges to its stationary distribution and enters the sampling phase. We can see that convergence takes quite a bit longer than for SGD. Since we are not content with finding a single parametrization, we also let the algorithm run for several steps after convergence to give it time to explore the mode. If you are familiar with MCMC, you probably noticed the similarity between SGLD's optimization phase and the MCMC burn-in phase.
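The post's plotting code is not shown, but a sketch like this visualizes the trajectory using the losses returned by train_logreg:

  # Assumed plotting sketch: the loss traces make the noisy optimization
  # phase and the flat, jittery sampling phase visible.
  plot(trainlosses, label = "train", xlabel = "t", ylabel = "loss")
  plot!(testlosses, label = "test")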

We can now extract the last \( n=2000 \) samples to approximate the posterior around the mode. Using these we can plot the marginal posteriors of the model parameters, as sketched below.
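The extraction itself is not shown in the post; a minimal sketch could look like this (the samples variable is reused further down, and the assumed column order Student, Balance, Income, bias follows from how paramvec flattens the parameters):

  # Take the parameter vectors recorded during the sampling phase (one row
  # per step in the weights matrix) as posterior samples.
  samples = weights[end - 1999:end, :]

  # Marginal posteriors, one histogram per parameter (as in Figure 3)
  histogram(samples, layout = 4, normalize = true, legend = nothing,
            title = ["Student" "Balance" "Income" "Bias"], size = (900, 500))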

Figure 3: The marginal posterior distributions. Not all of them are unimodal, but they are quite concentrated.

We can of course also make predictions. Let's evaluate the accuracy of the model on our test set. Using our posterior samples we approximate the posterior predictive distribution

\[ p(\tilde{y} \vert Y) \approx \frac{1}{n}\sum_{i=1}^n p(\tilde{y} \vert \theta_i, Y), \qquad \theta_i \sim p(\theta \vert Y) \]

for a new observation \( \tilde{y} \) and previously seen data \(Y\), where the \(\theta_i\) are our posterior samples, and predict that the person will default if \( E[\tilde{y} \vert Y] > 0.5 \).

  function posterior_predictive(x, w)
    # append a row of ones to emulate the addition of a bias term
    I = gpu(ones(size(x, 2)))
    x, w = gpu.((x, w))
    py = sigmoid.(w*[ x; I'])
  end

  function predict(x, w; thresh)
    ŷ = mean(posterior_predictive(x, w), dims=2) .> thresh
  end
  w = samples
  ŷ = predict(test_X, w, thresh=.5)
  accuracy = mean(ŷ .== test_y)
0.9666666666666667

The model displays about the same performance when trained with SGLD as with SGD. Great! While training the model took more time, we did not lose out on any performance, and our model is now capable of the amazing feat of answering “I don’t know” by giving us a high-entropy distribution when presented with a weird example.

      x = [1., 8, -60.0]  # an out-of-distribution customer: standardized balance 8, income -60
      py = cpu(posterior_predictive(x, w))
      histogram(py, title="Posterior probability of defaulting", size=(900, 500),
		legend=nothing, xlabel="Probability", ylabel="Density", normalize=true)
Figure 4: The Bayesian model captures the uncertainty in predictions in the posterior predictive distribution. Returning a high-entropy distribution is the model's way of saying “I don't know”.

Final words

This has been an introduction to SGLD, where we have seen how it works on a simple dataset and how it allows us to perform Bayesian learning. While I personally find the idea of SGLD super cool, it is not without issues. It converges significantly more slowly than SGD during optimization, and mixes slowly during the sampling phase. However, as the authors themselves put it: “the advantage of the method is its convenience”; we can perform optimization followed by posterior sampling in a single algorithm. There has also been follow-up work, such as Stochastic Gradient Fisher Scoring and SGMCMC, that aims to improve mixing.

You may have noticed that SGD and SGLD did not converge to the same local mode, which highlights another concern: SGLD only approximates a single posterior mode. When training complex models it is quite likely that there are other posterior modes which explain the data equally well, and which we would like our samples to reflect. Cyclical SGMCMC attempts to improve upon this issue.

Finally, I would like to mention two details I have glossed over: how many steps to train for, and when to start collecting samples. The authors provide a metric for when the injected noise dominates and the algorithm has entered the sampling phase, and also estimate how long you need to sample to explore the entire mode. Since these details are a bit technical and this post is long enough, I refer you to their paper.

Thank you for reading! I hope you found it interesting.

© Sebastian Callh 2020