Implementing Variational Autoencoders in Keras: Beyond the Quickstart Tutorial
Draft
Please do not share or link.
Keras is awesome. It is a very well-designed library that clearly abides by its guiding principles of modularity and extensibility, and allows us to easily assemble powerful, complex models from primitive building blocks. This has been demonstrated in numerous blog posts and tutorials, such as the excellent tutorial on Building Autoencoders in Keras. As the name suggests, that tutorial provides examples of how to implement various kinds of autoencoders in Keras, including the variational autoencoder (VAE) [1].
Like all autoencoders, the variational autoencoder is primarily used for unsupervised learning of hidden representations. However, it is fundamentally different from the usual neural network-based autoencoder in that it approaches the problem from a probabilistic perspective. It specifies a joint distribution over the observed and latent variables, and approximates the intractable posterior conditional density over the latent variables with variational inference, using an inference network [2] [3] (or more classically, a recognition model [4]) to amortize the cost of inference.
While the examples in the aforementioned tutorial do well to showcase the versatility of Keras on a wide range of autoencoder model architectures, its implementation of the variational autoencoder doesn't properly take advantage of Keras' modular design, making it difficult to generalize and extend in important ways. As we will see, it relies on implementing custom layers and constructs that are restricted to a specific instance of variational autoencoders. This is a shame, because when combined, Keras' building blocks are powerful enough to encapsulate most variants of the variational autoencoder and, more generally, recognition-generative model combinations for which the generative model belongs to the large family of deep latent Gaussian models (DLGMs) [5].
The goal of this post is to propose a clean and elegant alternative implementation that takes better advantage of Keras' modular design. It is not intended as a tutorial on variational autoencoders [*]. Rather, we study variational autoencoders as a special case of variational inference in deep latent Gaussian models with inference networks, and demonstrate how we can use Keras to implement them in a modular fashion such that they can be easily adapted to approximate inference in various common problems with different (non-Gaussian) likelihoods, such as classification with Bayesian logistic / softmax regression.
This first post will lay the groundwork for a series of future posts that explore ways to extend this basic modular framework to implement the more powerful methods proposed in the latest research, such as normalizing flows for building richer posterior approximations [6], importance weighted autoencoders [7], the Gumbel-softmax trick for inference in discrete latent variables [8], and even the most recent GAN-based density-ratio estimation techniques for likelihood-free inference [9] [10].
Model specification
First, it is important to understand that the variational autoencoder is not a way to train generative models. Rather, the generative model is a component of the variational autoencoder and is, in general, a deep latent Gaussian model. In particular, let \(\mathbf{x}\) be a local observed variable and \(\mathbf{z}\) its corresponding local latent variable, with joint distribution

\(p_{\theta}(\mathbf{x}, \mathbf{z}) = p_{\theta}(\mathbf{x} \mid \mathbf{z}) \, p(\mathbf{z}).\)
In Bayesian modelling, we assume the distribution of observed variables to be governed by the latent variables. Latent variables are drawn from a prior density \(p(\mathbf{z})\) and related to the observations through the likelihood \(p_{\theta}(\mathbf{x} \mid \mathbf{z})\). Deep latent Gaussian models (DLGMs) are a general class of models where the observed variable is governed by a hierarchy of latent variables, and the latent variables at each level of the hierarchy are Gaussian a priori [5].
In a typical instance of the variational autoencoder, we have only a single layer of latent variables with a Normal prior distribution,

\(p(\mathbf{z}) = \mathcal{N}(\mathbf{z} \mid \mathbf{0}, \mathbf{I}).\)
Now, each local latent variable is related to its corresponding observation through the likelihood \(p_{\theta}(\mathbf{x} \mid \mathbf{z})\), which can be viewed as a probabilistic decoder. Given a hidden lower-dimensional representation (or "code") \(\mathbf{z}\), it "decodes" it into a distribution over the observation \(\mathbf{x}\).
Decoder
In this example, we define \(p_{\theta}(\mathbf{x} \mid \mathbf{z})\) to be a multivariate Bernoulli whose probabilities are computed from \(\mathbf{z}\) using a fully-connected neural network with a single hidden layer,

\(p_{\theta}(\mathbf{x} \mid \mathbf{z}) = \mathrm{Bern}\left(\mathbf{x} \mid \sigma(\mathbf{W}_2 \, h(\mathbf{W}_1 \mathbf{z} + \mathbf{b}_1) + \mathbf{b}_2)\right),\)

where \(\sigma\) is the logistic sigmoid function, \(h\) is some nonlinearity, and the model parameters \(\theta = \{ \mathbf{W}_1, \mathbf{W}_2, \mathbf{b}_1, \mathbf{b}_2 \}\) consist of the weights and biases of this neural network.
It is straightforward to implement this in Keras with the Sequential model API:
decoder = Sequential([
    Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
    Dense(original_dim, activation='sigmoid')
])
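To make the computation concrete, here is a minimal NumPy sketch of this decoder's forward pass, with \(h\) taken to be the ReLU nonlinearity as in the snippet above; the randomly initialized parameters are purely illustrative stand-ins for the learned \(\theta\):

```python
import numpy as np

rng = np.random.default_rng(0)

latent_dim, intermediate_dim, original_dim = 2, 256, 784

# Illustrative, randomly initialized parameters theta = {W1, W2, b1, b2}
W1 = rng.normal(scale=0.01, size=(latent_dim, intermediate_dim))
b1 = np.zeros(intermediate_dim)
W2 = rng.normal(scale=0.01, size=(intermediate_dim, original_dim))
b2 = np.zeros(original_dim)

def sigmoid(a):
    return 1.0 / (1.0 + np.exp(-a))

def decode(z):
    """Map latent codes z to Bernoulli probabilities over x."""
    hidden = np.maximum(0.0, z @ W1 + b1)  # h: ReLU nonlinearity
    return sigmoid(hidden @ W2 + b2)       # probabilities in (0, 1)

z = rng.normal(size=(5, latent_dim))       # a batch of latent codes
probs = decode(z)
assert probs.shape == (5, original_dim)
assert np.all((probs > 0) & (probs < 1))
```

Each row of the output parameterizes an independent Bernoulli distribution over the corresponding observation, which is exactly what the final sigmoid layer of the Keras decoder computes.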
You can view a summary of the model parameters \(\theta\) by calling decoder.summary(). Additionally, you can produce a high-level diagram of the network architecture, and optionally the input and output shapes of each layer, using plot_model from the keras.utils.vis_utils module. Although our architecture is about as simple as it gets, it is included in the figure below as an example of what the diagrams look like.
Note that by fixing \(\mathbf{W}_1\), \(\mathbf{b}_1\) and \(h\) to be the identity matrix, the zero vector, and the identity function, respectively (or equivalently dropping the first Dense layer in the snippet above altogether), we recover logistic factor analysis. With similarly minor modifications, we can recover other members from the family of DLGMs, which include nonlinear factor analysis, nonlinear Gaussian belief networks, sigmoid belief networks, and many others [5].
Inference
Having specified the generative process, we would now like to perform inference on the latent variables and model parameters, \(\mathbf{z}\) and \(\theta\), respectively. In particular, our goal is to compute the posterior \(p_{\theta}(\mathbf{z} \mid \mathbf{x})\), the conditional density of the latent variable \(\mathbf{z}\) given the observed variable \(\mathbf{x}\). Additionally, we wish to optimize the model parameters \(\theta\) with respect to the marginal likelihood \(p_{\theta}(\mathbf{x})\). Both tasks depend on the marginal likelihood, which requires marginalizing out the latent variables \(\mathbf{z}\). In general, this is computationally intractable, requiring exponential time to compute. Or, it is analytically intractable and cannot be evaluated in closed form, as in our case, where the Gaussian prior is non-conjugate to the Bernoulli likelihood.
To circumvent this intractability we turn to variational inference, which formulates inference as an optimization problem. It seeks an approximate posterior \(q_{\phi}(\mathbf{z} \mid \mathbf{x})\) with variational parameters \(\phi\) closest in Kullback-Leibler (KL) divergence to the true posterior.
With the luck we've had so far, it shouldn't come as a surprise anymore that this too is intractable. It also depends on the log marginal likelihood, whose intractability is the reason we appealed to approximate inference in the first place. Instead, we maximize an alternative objective function, the evidence lower bound (ELBO), which is expressed as

\(\mathrm{ELBO}(\theta, \phi) = \mathbb{E}_{q_{\phi}(\mathbf{z} \mid \mathbf{x})} [ \log p_{\theta}(\mathbf{x}, \mathbf{z}) - \log q_{\phi}(\mathbf{z} \mid \mathbf{x}) ].\)
Importantly, the ELBO is a lower bound on the log marginal likelihood. Therefore, maximizing it with respect to the model parameters \(\theta\) approximately maximizes the log marginal likelihood. Additionally, maximizing it with respect to the variational parameters \(\phi\) can be shown to minimize \(\mathrm{KL} [ q_{\phi}(\mathbf{z} \mid \mathbf{x}) \| p_{\theta}(\mathbf{z} \mid \mathbf{x}) ]\). It also turns out that this KL divergence determines the tightness of the lower bound: we have equality iff the KL divergence is zero, which happens iff \(q_{\phi}(\mathbf{z} \mid \mathbf{x}) = p_{\theta}(\mathbf{z} \mid \mathbf{x})\). Hence, simultaneously maximizing the ELBO with respect to \(\theta\) and \(\phi\) kills two birds with one stone.
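To build intuition for the bound, we can verify it numerically in a toy conjugate model where the log marginal likelihood is tractable. The 1-D Gaussian model below is chosen purely for illustration (it is not the Bernoulli model of this post): any choice of \(q\) gives a lower bound, which becomes tight when \(q\) equals the true posterior.

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

# Toy conjugate model where everything is tractable:
#   z ~ N(0, 1),  x | z ~ N(z, 1)  =>  p(x) = N(0, 2),
#   true posterior p(z | x) = N(x / 2, 1 / 2).
x = 1.3
log_marginal = norm.logpdf(x, loc=0.0, scale=np.sqrt(2.0))

def elbo(m, s, n_samples=200_000):
    """Monte Carlo estimate of E_q[log p(x, z) - log q(z)]."""
    z = rng.normal(m, s, size=n_samples)
    log_joint = norm.logpdf(x, z, 1.0) + norm.logpdf(z, 0.0, 1.0)
    log_q = norm.logpdf(z, m, s)
    return np.mean(log_joint - log_q)

# Any q gives a lower bound; the true posterior makes it tight.
loose = elbo(m=0.0, s=1.0)
tight = elbo(m=x / 2, s=np.sqrt(0.5))

assert loose <= log_marginal
assert abs(tight - log_marginal) < 1e-2
```

Note that when \(q\) equals the true posterior, every sample of \(\log p(x, z) - \log q(z)\) equals \(\log p(x)\) exactly, so the gap (and the estimator's variance) vanishes.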
Encoder
The approximate posterior \(q_{\phi}(\mathbf{z} \mid \mathbf{x})\) can be viewed as a probabilistic encoder; it is also known as an inference network [2] [3] or, more classically, a recognition network [4].
In classical variational inference, every local latent variable \(\mathbf{z}_i\) corresponding to observed variable \(\mathbf{x}_i\) has its own set of local variational parameters \(\phi_i\). For example, \(q_{\phi_i}(\mathbf{z}_i) = \mathcal{N}(\mathbf{z}_i \mid \mathbf{\mu}_i, \mathrm{diag}(\mathbf{\sigma}_i^2))\), with variational parameters \(\phi_i = \{ \mathbf{\mu}_i, \mathbf{\sigma}_i \}\).
The number of local variational parameters grows with the size of the observed data. Furthermore, a new set of parameters must be optimized for unseen test data points. We amortize the cost of inference by introducing an inference network which outputs the local variational parameters \(\phi_i\) given \(\mathbf{x}_i\) as input. This approximation allows statistical strength to be shared across observed datapoints, and also generalizes to unseen test points.
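A back-of-the-envelope comparison illustrates the point. Assuming the layer sizes used later in this post (784-dimensional inputs, a 256-unit hidden layer, 2 latent dimensions) and the 60,000 MNIST training points, the inference network has fewer parameters than per-datapoint variational parameters would require, and its cost does not grow with the dataset:

```python
# Without amortization: each datapoint x_i carries its own
# phi_i = {mu_i, sigma_i}, i.e. 2 * latent_dim parameters.
n_datapoints, latent_dim = 60_000, 2
per_datapoint = n_datapoints * 2 * latent_dim  # grows with the dataset

# With an inference network (784 -> 256 -> 2 * latent_dim, weights + biases),
# the parameter count is fixed, no matter how many datapoints we observe:
amortized = 784 * 256 + 256 + 256 * (2 * latent_dim) + 2 * latent_dim

assert amortized < per_datapoint
```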
Continuing with the example, we now have \(q_{\phi}(\mathbf{z}_i \mid \mathbf{x}_i) = \mathcal{N}(\mathbf{z}_i \mid \mathbf{\mu}_{\phi}(\mathbf{x}_i), \mathrm{diag}(\mathbf{\sigma}_{\phi}(\mathbf{x}_i)^2))\), where the variational parameters \(\phi\) are now the weights of the network that computes \(\mathbf{\mu}_{\phi}\) and \(\mathbf{\sigma}_{\phi}\), shared across all datapoints.
In the specific case of autoencoders, the network that maps an observation to (a distribution over) its latent code acts as a probabilistic encoder. In the more general case of amortized variational inference, it is known as a recognition model, or an inference network.
x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)

z_mu = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

z_sigma = Lambda(lambda t: K.exp(.5*t))(z_log_var)
figure here
Reparameterization using Merge Layers
To perform gradient-based optimization of the ELBO, we require its gradients with respect to the variational parameters \(\phi\), which are generally intractable. Currently, the dominant approach for circumventing this is Monte Carlo (MC) estimation of the gradients. There are several estimators, based on different variance reduction techniques. However, for continuous latent variables, the reparameterization gradients can be shown to have the lowest variance among competing estimators.
The ELBO can be written as an expectation of a multivariate function \(f(\mathbf{x}, \mathbf{z}) = \log p_{\theta}(\mathbf{x}, \mathbf{z}) - \log q_{\phi}(\mathbf{z} \mid \mathbf{x})\) over the distribution \(q_{\phi}(\mathbf{z} \mid \mathbf{x})\).
Specifying \(\mathbf{z} = \mathbf{\mu}_{\phi}(\mathbf{x}) + \mathbf{\sigma}_{\phi}(\mathbf{x}) \odot \mathbf{\epsilon}\) with \(\mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})\) lets us rewrite the expectation over \(q_{\phi}(\mathbf{z} \mid \mathbf{x})\) as an expectation over \(p(\mathbf{\epsilon})\), which no longer depends on \(\phi\). Differentiating under this expectation gives us the gradient of the ELBO above.
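A quick NumPy check (with fixed scalar \(\mu\) and \(\sigma\) for simplicity, standing in for the inference network outputs) confirms that the reparameterized samples follow the intended Gaussian, while the transformation itself remains a deterministic function of the noise:

```python
import numpy as np

rng = np.random.default_rng(0)

mu, sigma = 1.5, 0.7                 # stand-ins for mu_phi(x), sigma_phi(x)
eps = rng.normal(size=1_000_000)     # auxiliary noise, independent of phi

z = mu + sigma * eps                 # deterministic in mu, sigma and eps

# The samples follow N(mu, sigma^2), as required ...
assert abs(z.mean() - mu) < 1e-2
assert abs(z.std() - sigma) < 1e-2

# ... and the sampling step is differentiable: dz/dmu = 1, dz/dsigma = eps,
# so gradients of any f(z) can be propagated back to the variational parameters.
```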
Assume z_mu and z_sigma are the outputs of some layers. Then, using the merge layers Add and Multiply:
eps = Input(shape=(latent_dim,))
z_eps = Multiply()([z_sigma, eps])
z = Add()([z_mu, z_eps])
Contrast this with the original tutorial's implementation, which relies on a monolithic Lambda layer that simultaneously draws samples from a hard-coded base distribution and performs the reparameterization. Our implementation achieves a more appropriate level of modularity and abstraction. It makes clear that each of these atomic building blocks is itself a deterministic transformation, and that together they compose a deterministic transformation. The source of stochasticity comes from the input, which we are able to tweak at test time. This perspective will also pay off later, e.g. for the Gumbel-softmax trick.
eps = Input(tensor=K.random_normal(shape=(K.shape(x)[0], latent_dim)))
For the sake of illustration, we have fixed z_sigma and z_mu as Input layers, which is why they appear as InputLayer in the diagram. In reality, they are the outputs of the inference network, computing \(\mathbf{\mu}_{\phi}(\mathbf{x})\) and \(\mathbf{\sigma}_{\phi}(\mathbf{x})\), which we specify now.
KL Divergence
The KL divergence term of the ELBO can be seen as regularizing the latent space.
class KLDivergenceLayer(Layer):

    """ Identity transform layer that adds KL divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):
        mu, log_var = inputs
        kl_batch = - .5 * K.sum(1 + log_var -
                                K.square(mu) -
                                K.exp(log_var), axis=-1)
        self.add_loss(K.mean(kl_batch), inputs=inputs)
        return inputs
z_mu, z_log_var = KLDivergenceLayer()([z_mu, z_log_var])
By itself, this model will learn to ignore the input and map all outputs to zero. It is only when we tack on the decoder that the reconstruction likelihood is introduced. Only then will we reconcile the likelihood of the observed data with our prior to form the posterior over latent codes.
At this stage we could specify prob_encoder = Model(inputs=x, outputs=[z_mu, z_sigma]) and compile it with prob_encoder.compile(optimizer='rmsprop', loss=None). If we were to fit it, it would trivially map all inputs to a mean of 0 and a standard deviation of 1, thus learning the prior distribution.
Note that while the inputs mu and log_var are of shape (batch_size, latent_dim), the loss we add must be a scalar. This is unlike a loss function specified in model compile, which must return a loss vector of shape (batch_size,), since it requires a loss for each datapoint in the batch for sample weighting.
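As a sanity check on the closed-form expression computed in KLDivergenceLayer, we can compare it against a Monte Carlo estimate of \(\mathrm{KL}[q \| p]\) for a single, arbitrarily chosen latent dimension:

```python
import numpy as np
from scipy.stats import norm

rng = np.random.default_rng(0)

mu, log_var = 0.8, -0.4                # arbitrary illustrative values
sigma = np.exp(0.5 * log_var)

# Closed-form KL[N(mu, sigma^2) || N(0, 1)], as in the layer above
kl_closed = -0.5 * (1 + log_var - mu**2 - np.exp(log_var))

# Monte Carlo estimate of E_q[log q(z) - log p(z)]
z = rng.normal(mu, sigma, size=500_000)
kl_mc = np.mean(norm.logpdf(z, mu, sigma) - norm.logpdf(z, 0.0, 1.0))

assert abs(kl_closed - kl_mc) < 1e-2
```

The closed form is exact and cheap, which is why we add it analytically through the layer rather than estimating it from samples.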
Putting it all together
x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)

z_mu = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

z_mu, z_log_var = KLDivergenceLayer()([z_mu, z_log_var])
z_sigma = Lambda(lambda t: K.exp(.5*t))(z_log_var)

eps = Input(tensor=K.random_normal(shape=(K.shape(x)[0], latent_dim)))
z_eps = Multiply()([z_sigma, eps])
z = Add()([z_mu, z_eps])

decoder = Sequential([
    Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
    Dense(original_dim, activation='sigmoid')
])

x_mean = decoder(z)
vae = Model(inputs=[x, eps], outputs=x_mean)
vae.compile(optimizer='rmsprop', loss=nll)
When combined end-to-end, the inference network and the deep latent Gaussian model can be seen as having an autoencoder structure. Indeed, this general structure contains the variational autoencoder as a special case and, more traditionally, the Helmholtz machine. Even more generally, we can use this structure to perform amortized variational inference in complex generative models for a wide array of supervised, unsupervised and semi-supervised tasks.
The point of this tutorial is to illustrate the general framework for performing amortized variational inference using Keras, treating the inference network (approximate posterior) and the generative network (likelihood) as black boxes. What we've used for the encoder and decoder, each with a single fully-connected hidden layer, is perhaps the minimum viable architecture. In the examples directory, Keras provides a more sophisticated variational autoencoder with deconvolutional layers. Its architecture definitions can be trivially copy-pasted here without needing to modify anything else.
Parameter Learning
We load the training data as usual. Note that the vae is now explicitly specified with the random noise source as an auxiliary input. This allows us to easily control the base distribution \(p(\mathbf{\epsilon})\), and also how we draw Monte Carlo samples of \(\mathbf{z}\) for each datapoint \(\mathbf{x}\). Usually we just stick with a simple isotropic Gaussian distribution and draw a different MC sample for each datapoint.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, original_dim) / 255.
x_test = x_test.reshape(-1, original_dim) / 255.
In the original tutorial's implementation, model fitting feels less intuitive: the model is compiled with loss=None explicitly specified, which raises a warning, and when fit is called the targets argument is left unspecified, with the reconstruction loss optimized through a custom layer. In our implementation, the mapping from mathematical problem formulation to code is more natural and straightforward: it's easy to understand at a glance from our call to the fit method that we're training a probabilistic autoencoder.
vae.fit(x_train,
        x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, x_test))
Personally, I prefer this view, since all sources of stochasticity emanate from the inputs to the model.
Model evaluation
encoder = Model(x, z_mu)

# display a 2D plot of the digit classes in the latent space
z_test = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(z_test[:, 0], z_test[:, 1], c=y_test,
            alpha=.4, s=3**2, cmap='viridis')
plt.colorbar()
plt.show()
# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28

# linearly spaced coordinates on the unit square were transformed
# through the inverse CDF (ppf) of the Gaussian to produce values
# of the latent variables z, since the prior of the latent space
# is Gaussian
u_grid = np.dstack(np.meshgrid(np.linspace(0.05, 0.95, n),
                               np.linspace(0.05, 0.95, n)))
z_grid = norm.ppf(u_grid)
x_decoded = decoder.predict(z_grid.reshape(n*n, 2))
x_decoded = x_decoded.reshape(n, n, digit_size, digit_size)

plt.figure(figsize=(10, 10))
plt.imshow(np.block(list(map(list, x_decoded))), cmap='gray')
plt.show()
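The inverse-CDF trick in the snippet above deserves a brief check: mapping linearly spaced points on the unit interval through norm.ppf produces Gaussian quantiles, so the resulting grid is symmetric about zero and covers the bulk of the standard normal prior:

```python
import numpy as np
from scipy.stats import norm

# linearly spaced coordinates on (0, 1), as in the manifold plot above
u = np.linspace(0.05, 0.95, 15)
z = norm.ppf(u)  # Gaussian quantiles: equal prior mass between grid points

# symmetric about zero, and the CDF maps the grid back to uniform spacing
assert np.allclose(z, -z[::-1])
assert np.allclose(norm.cdf(z), u)
```

Sampling the latent space this way allocates grid points according to the prior, so the decoded digits vary smoothly across the figure instead of bunching up in low-density regions.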
Recap
- Demonstration of the Sequential and functional Model APIs
- Custom auxiliary layers that augment the model loss
- Fixing inputs to sources of stochasticity
- Reparameterization using merge layers
What's next
Normalizing flows
We will illustrate how to employ the simple Gumbel-Softmax reparameterization to build a categorical VAE with discrete latent variables.
We can easily extend KLDivergenceLayer to use an auxiliary density ratio estimator function, instead of evaluating the KL divergence in the closed-form expression above. This relaxes the requirement on the approximate posterior \(q(\mathbf{z} \mid \mathbf{x})\) (and incidentally, also the prior \(p(\mathbf{z})\)) to yield tractable densities, at the cost of maximizing a cruder estimate of the ELBO. This is known as Adversarial Variational Bayes [9], and is an important line of recent research that extends the applicability of variational inference to arbitrarily expressive implicit probabilistic models [10].
Footnotes
[*] For a complete treatment of variational autoencoders, and variational inference in general, I highly recommend:

References
[1]  D. P. Kingma and M. Welling, "Auto-Encoding Variational Bayes," in Proceedings of the 2nd International Conference on Learning Representations (ICLR), 2014.
[2]  Edward tutorial on Inference Networks 
[3]  Section "Recognition models and amortised inference" in Shakir's blog post. 
[4]  P. Dayan, G. E. Hinton, R. M. Neal, and R. S. Zemel, "The Helmholtz machine," Neural Computation, vol. 7, no. 5, pp. 889–904, 1995. http://doi.org/10.1162/neco.1995.7.5.889
[5]  D. J. Rezende, S. Mohamed, and D. Wierstra, "Stochastic Backpropagation and Approximate Inference in Deep Generative Models," in Proceedings of the 31st International Conference on Machine Learning, 2014, vol. 32, pp. 1278–1286.
[6]  D. Rezende and S. Mohamed, "Variational Inference with Normalizing Flows," in Proceedings of the 32nd International Conference on Machine Learning, 2015, vol. 37, pp. 1530–1538. 
[7]  Y. Burda, R. Grosse, and R. Salakhutdinov, "Importance Weighted Autoencoders," in Proceedings of the 3rd International Conference on Learning Representations (ICLR), 2015. 
[8]  E. Jang, S. Gu, and B. Poole, "Categorical Reparameterization with Gumbel-Softmax," in Proceedings of the 5th International Conference on Learning Representations (ICLR), 2017.
[9]  L. Mescheder, S. Nowozin, and A. Geiger, "Adversarial Variational Bayes: Unifying Variational Autoencoders and Generative Adversarial Networks," in Proceedings of the 34th International Conference on Machine Learning, 2017, vol. 70, pp. 2391–2400.
[10]  D. Tran, R. Ranganath, and D. Blei, "Hierarchical Implicit Models and Likelihood-Free Variational Inference," to appear in Advances in Neural Information Processing Systems, 2017.
Appendix
Below, you can find:
- The accompanying Jupyter Notebook used to generate the diagrams and plots in this post.
- The above snippets combined in a single executable Python file:
vae/variational_autoencoder_improved.py (Source)
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras import backend as K
from keras.layers import Input, Dense, Lambda, Layer, Add, Multiply
from keras.models import Model, Sequential
from keras.datasets import mnist

batch_size = 100
original_dim = 784
latent_dim = 2
intermediate_dim = 256
epochs = 50
epsilon_std = 1.0


def nll(y_true, y_pred):
    """ Bernoulli negative log likelihood. """

    # keras.losses.binary_crossentropy gives the mean
    # over the last axis. We require the sum.
    return K.sum(K.binary_crossentropy(y_true, y_pred), axis=-1)


class KLDivergenceLayer(Layer):

    """ Identity transform layer that adds KL divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):
        mu, log_var = inputs
        kl_batch = - .5 * K.sum(1 + log_var -
                                K.square(mu) -
                                K.exp(log_var), axis=-1)
        self.add_loss(K.mean(kl_batch), inputs=inputs)
        return inputs


x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)

z_mu = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

z_mu, z_log_var = KLDivergenceLayer()([z_mu, z_log_var])
z_sigma = Lambda(lambda t: K.exp(.5*t))(z_log_var)

eps = Input(tensor=K.random_normal(stddev=epsilon_std,
                                   shape=(K.shape(x)[0], latent_dim)))
z_eps = Multiply()([z_sigma, eps])
z = Add()([z_mu, z_eps])

decoder = Sequential([
    Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
    Dense(original_dim, activation='sigmoid')
])

x_mean = decoder(z)

vae = Model(inputs=[x, eps], outputs=x_mean)
vae.compile(optimizer='rmsprop', loss=nll)

# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, original_dim) / 255.
x_test = x_test.reshape(-1, original_dim) / 255.
vae.fit(x_train,
        x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, x_test))

encoder = Model(x, z_mu)

# display a 2D plot of the digit classes in the latent space
z_test = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(z_test[:, 0], z_test[:, 1], c=y_test,
            alpha=.4, s=3**2, cmap='viridis')
plt.colorbar()
plt.show()

# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28

# linearly spaced coordinates on the unit square were transformed
# through the inverse CDF (ppf) of the Gaussian to produce values
# of the latent variables z, since the prior of the latent space
# is Gaussian
u_grid = np.dstack(np.meshgrid(np.linspace(0.05, 0.95, n),
                               np.linspace(0.05, 0.95, n)))
z_grid = norm.ppf(u_grid)
x_decoded = decoder.predict(z_grid.reshape(n*n, 2))
x_decoded = x_decoded.reshape(n, n, digit_size, digit_size)

plt.figure(figsize=(10, 10))
plt.imshow(np.block(list(map(list, x_decoded))), cmap='gray')
plt.show()