Implementing Variational Autoencoders in Keras: Beyond the Quickstart Tutorial


Keras is awesome. It is a very well-designed library that clearly abides by its guiding principles of modularity and extensibility, and allows us to easily assemble powerful, complex models from primitive building blocks. This has been demonstrated in numerous blog posts and tutorials, such as the excellent tutorial on Building Autoencoders in Keras. As the name suggests, that tutorial provides examples of how to implement various kinds of autoencoders in Keras, including the variational autoencoder (VAE) [1].

../../images/vae/result_combined.png

Visualization of 2D manifold of MNIST digits (left) and the representation of digits in latent space colored according to their digit labels (right).

Like all autoencoders, the variational autoencoder is primarily used for unsupervised learning of hidden representations. However, it is fundamentally different from the usual neural network-based autoencoder in that it approaches the problem from a probabilistic perspective: it specifies a joint distribution over the observed and latent variables, and approximates the intractable posterior density over the latent variables with variational inference, using an inference network [2] [3] (or, more classically, a recognition model [4]) to amortize the cost of inference.

While the examples in the aforementioned tutorial do well to showcase the versatility of Keras on a wide range of autoencoder model architectures, its implementation of the variational autoencoder doesn't properly take advantage of Keras' modular design, making it difficult to generalize and extend in important ways. As we will see, it relies on implementing custom layers and constructs that are restricted to a specific instance of variational autoencoders. This is a shame because when combined, Keras' building blocks are powerful enough to encapsulate most variants of the variational autoencoder and more generally, recognition-generative model combinations for which the generative model belongs to a large family of deep latent Gaussian models (DLGMs) [5].

The goal of this post is to propose a clean and elegant alternative implementation that takes better advantage of Keras' modular design. It is not intended as a tutorial on variational autoencoders [*]. Rather, we study variational autoencoders as a specific case of variational inference in deep latent Gaussian models with inference networks, and demonstrate how we can use Keras to implement them in a modular fashion such that they can be easily adapted to approximate inference in various common problems with different (non-Gaussian) likelihoods, such as classification with Bayesian logistic / softmax regression.

This first post will lay the groundwork for a series of future posts that explore ways to extend this basic modular framework to implement the more powerful methods proposed in the latest research, such as the normalizing flows for building richer posterior approximations [6], importance weighted autoencoders [7], the Gumbel-softmax trick for inference in discrete latent variables [8], and even the most recent GAN-based density-ratio estimation techniques for likelihood-free inference [9] [10].

Model specification

First, it is important to understand that the variational autoencoder is not a way to train generative models. Rather, the generative model is a component of the variational autoencoder and is, in general, a deep latent Gaussian model. In particular, let \(\mathbf{x}\) be a local observed variable and \(\mathbf{z}\) its corresponding local latent variable, with joint distribution

\begin{equation*} p_{\theta}(\mathbf{x}, \mathbf{z}) = p_{\theta}(\mathbf{x} | \mathbf{z}) p(\mathbf{z}). \end{equation*}

In Bayesian modelling, we assume the distribution of observed variables to be governed by the latent variables. Latent variables are drawn from a prior density \(p(\mathbf{z})\) and related to the observations through the likelihood \(p_{\theta}(\mathbf{x} | \mathbf{z})\). Deep latent Gaussian models (DLGMs) are a general class of models where the observed variable is governed by a hierarchy of latent variables, and the latent variables at each level of the hierarchy are Gaussian a priori [5].
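Schematically, a DLGM with \(L\) layers of latent variables can be sketched as

\begin{align*} \mathbf{z}_L & \sim \mathcal{N}(\mathbf{0}, \mathbf{I}), \\ \mathbf{z}_l \mid \mathbf{z}_{l+1} & \sim \mathcal{N}\big( \mathbf{\mu}_l(\mathbf{z}_{l+1}), \mathbf{\Sigma}_l \big), \quad l = 1, \dotsc, L - 1, \\ \mathbf{x} \mid \mathbf{z}_1 & \sim p_{\theta}(\mathbf{x} | \mathbf{z}_1), \end{align*}

where the mean of each layer is computed from the layer above by a neural network. This is only a simplified sketch; see [5] for the exact formulation, including how the covariances are parameterized.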

In a typical instance of the variational autoencoder, we have only a single layer of latent variables with a Normal prior distribution,

\begin{equation*} p(\mathbf{z}) = \mathcal{N}(\mathbf{0}, \mathbf{I}). \end{equation*}

Now, each local latent variable is related to its corresponding observation through the likelihood \(p_{\theta}(\mathbf{x} | \mathbf{z})\), which can be viewed as a probabilistic decoder. Given a hidden lower-dimensional representation (or "code") \(\mathbf{z}\), it "decodes" it into a distribution over the observation \(\mathbf{x}\).

Decoder

In this example, we define \(p_{\theta}(\mathbf{x} | \mathbf{z})\) to be a multivariate Bernoulli whose probabilities are computed from \(\mathbf{z}\) using a fully-connected neural network with a single hidden layer,

\begin{align*} p_{\theta}(\mathbf{x} | \mathbf{z}) & = \mathrm{Bern}( \sigma( \mathbf{W}_2 \mathbf{h} + \mathbf{b}_2 ) ), \\ \mathbf{h} & = h(\mathbf{W}_1 \mathbf{z} + \mathbf{b}_1), \end{align*}

where \(\sigma\) is the logistic sigmoid function, \(h\) is some non-linearity, and the model parameters \(\theta = \{ \mathbf{W}_1, \mathbf{W}_2, \mathbf{b}_1, \mathbf{b}_2 \}\) consist of the weights and biases of this neural network.

It is straightforward to implement this in Keras with the Sequential model API:

decoder = Sequential([
  Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
  Dense(original_dim, activation='sigmoid')
])

You can view a summary of the model parameters \(\theta\) by calling decoder.summary(). Additionally, you can produce a high-level diagram of the network architecture, and optionally the input and output shapes of each layer using plot_model from the keras.utils.vis_utils module. Although our architecture is about as simple as it gets, it is included in the figure below as an example of what the diagrams look like.
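For instance, a minimal invocation might look like the following (the output filename is purely illustrative):

from keras.utils.vis_utils import plot_model

# render the decoder architecture to a file, including layer shapes
plot_model(decoder, to_file='decoder.svg', show_shapes=True)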

../../images/vae/decoder.svg

Decoder architecture.

Note that by fixing \(\mathbf{W}_1\), \(\mathbf{b}_1\) and \(h\) to be the identity matrix, the zero vector, and the identity function, respectively (or equivalently dropping the first Dense layer in the snippet above altogether), we recover logistic factor analysis. With similarly minor modifications, we can recover other members from the family of DLGMs, which include non-linear factor analysis, non-linear Gaussian belief networks, sigmoid belief networks, and many others [5].
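For example, a sketch of such a logistic factor analysis decoder (the name lfa_decoder is ours, purely for illustration) is simply:

lfa_decoder = Sequential([
    # Bernoulli probabilities are now a linear function of z,
    # squashed through the logistic sigmoid
    Dense(original_dim, input_dim=latent_dim, activation='sigmoid')
])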

Inference

Having specified the generative process, we would now like to perform inference on the latent variables and model parameters, \(\mathbf{z}\) and \(\theta\), respectively. In particular, our goal is to compute the posterior \(p_{\theta}(\mathbf{z} | \mathbf{x})\), the conditional density of the latent variable \(\mathbf{z}\) given the observed variable \(\mathbf{x}\). Additionally, we wish to optimize the model parameters \(\theta\) with respect to the marginal likelihood \(p_{\theta}(\mathbf{x})\). Both tasks require the marginal likelihood, which entails marginalizing out the latent variables \(\mathbf{z}\). In general, this is computationally intractable, requiring exponential time to compute, or it is analytically intractable and cannot be evaluated in closed form, as in our case, where the Gaussian prior is non-conjugate to the Bernoulli likelihood.

To circumvent this intractability we turn to variational inference, which formulates inference as an optimization problem. It seeks an approximate posterior \(q_{\phi}(\mathbf{z} | \mathbf{x})\) with variational parameters \(\phi\) closest in Kullback-Leibler (KL) divergence to the true posterior.

\begin{equation*} \phi^* = \mathrm{argmin}_{\phi} \mathrm{KL} [q_{\phi}(\mathbf{z} | \mathbf{x}) \| p_{\theta}(\mathbf{z} | \mathbf{x}) ] \end{equation*}

Given the luck we've had so far, it shouldn't come as a surprise that this objective is also intractable: it depends on the log marginal likelihood, whose intractability is the very reason we appealed to approximate inference in the first place. Instead, we maximize an alternative objective function, the evidence lower bound (ELBO), which is expressed as

\begin{align*} \mathrm{ELBO}(q) &= \mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [ \log p_{\theta}(\mathbf{x} | \mathbf{z}) + \log p(\mathbf{z}) - \log q_{\phi}(\mathbf{z} | \mathbf{x}) ] \\ &= \mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [ \log p_{\theta}(\mathbf{x} | \mathbf{z}) ] - \mathrm{KL} [q_{\phi}(\mathbf{z} | \mathbf{x}) \| p(\mathbf{z}) ]. \end{align*}

Importantly, the ELBO is a lower bound on the log marginal likelihood. Therefore, maximizing it with respect to the model parameters \(\theta\) approximately maximizes the log marginal likelihood. Additionally, maximizing it with respect to the variational parameters \(\phi\) can be shown to minimize \(\mathrm{KL} [q_{\phi}(\mathbf{z} | \mathbf{x}) \| p_{\theta}(\mathbf{z} | \mathbf{x}) ]\). Moreover, the KL divergence determines the tightness of the bound: we have equality iff the KL divergence is zero, which happens iff \(q_{\phi}(\mathbf{z} | \mathbf{x}) = p_{\theta}(\mathbf{z} | \mathbf{x})\). Hence, simultaneously maximizing the ELBO with respect to \(\theta\) and \(\phi\) kills two birds with one stone.
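To see both claims at once, note that the log marginal likelihood decomposes as

\begin{equation*} \log p_{\theta}(\mathbf{x}) = \mathrm{ELBO}(q) + \mathrm{KL} [q_{\phi}(\mathbf{z} | \mathbf{x}) \| p_{\theta}(\mathbf{z} | \mathbf{x}) ]. \end{equation*}

Since the KL term is non-negative, the ELBO is indeed a lower bound on \(\log p_{\theta}(\mathbf{x})\), and for fixed \(\theta\), the gap between the two is exactly the KL divergence we set out to minimize.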

Encoder

The approximate posterior \(q_{\phi}(\mathbf{z} | \mathbf{x})\) goes by several names: a probabilistic encoder, an inference network [2] [3], or, more classically, a recognition model [4]. To motivate why we use one at all, consider first how inference would proceed without it.

In classical variational inference, every local latent variable \(\mathbf{z}_i\) corresponding to observed variable \(\mathbf{x}_i\) has its own set of local variational parameters \(\phi_i\). For example, \(q_{\phi_i}(\mathbf{z}_i) = \mathcal{N}(\mathbf{z}_i | \mathbf{\mu}_i, \mathrm{diag}(\mathbf{\sigma}_i^2))\), with variational parameters \(\phi_i = \{ \mathbf{\mu}_i, \mathbf{\sigma}_i \}\).

The number of local variational parameters grows with the size of the observed data, and a new set of parameters must be optimized for each unseen test point. We amortize the cost of inference by introducing an inference network that outputs the local variational parameters \(\phi_i\) given \(\mathbf{x}_i\) as input. This allows statistical strength to be shared across observed data points and generalizes to unseen test points.

Continuing with the example, we now have \(q_{\phi}(\mathbf{z}_i | \mathbf{x}_i) = \mathcal{N}(\mathbf{z}_i | \mathbf{\mu}_{\phi}(\mathbf{x}_i), \mathrm{diag}(\mathbf{\sigma}_{\phi}^2(\mathbf{x}_i)))\), where the variational parameters \(\phi\) are now the weights and biases of the network that computes \(\mathbf{\mu}_{\phi}\) and \(\mathbf{\sigma}_{\phi}\), shared across all data points.

In the specific case of autoencoders, the inference network plays the role of a probabilistic encoder: given an observation \(\mathbf{x}\), it "encodes" it into a distribution over its latent code \(\mathbf{z}\). Dropping the data point index, we define

\begin{equation*} q_{\phi}(\mathbf{z} | \mathbf{x}) = \mathcal{N}( \mathbf{z} | \mathbf{\mu}_{\phi}(\mathbf{x}), \mathrm{diag}(\mathbf{\sigma}_{\phi}^2(\mathbf{x})) ), \end{equation*}

where \(\mathbf{\mu}_{\phi}\) and \(\mathbf{\sigma}_{\phi}\) are computed from \(\mathbf{x}\) using a fully-connected neural network with a single hidden layer. This is straightforward to implement with Keras' functional API:
x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)

z_mu = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)
z_sigma = Lambda(lambda t: K.exp(.5*t))(z_log_var)


Reparameterization using Merge Layers

To perform gradient-based optimization of the ELBO, we require its gradients with respect to the variational parameters \(\phi\), which are generally intractable. Currently, the dominant approach for circumventing this is Monte Carlo (MC) estimation of the gradients. There are several estimators, based on different variance reduction techniques; for continuous latent variables, the reparameterization gradient estimator typically exhibits lower variance than its competitors.

The ELBO can be written as an expectation of a multivariate function \(f(\mathbf{x}, \mathbf{z}) = \log p_{\theta}(\mathbf{x} , \mathbf{z}) - \log q_{\phi}(\mathbf{z} | \mathbf{x})\) over distribution \(q_{\phi}(\mathbf{z} | \mathbf{x})\).
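A general-purpose option is the score-function (REINFORCE) estimator, which only requires that we can sample from \(q_{\phi}(\mathbf{z} | \mathbf{x})\) and evaluate its log density,

\begin{equation*} \nabla_{\phi} \mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [ f(\mathbf{x}, \mathbf{z}) ] = \mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [ f(\mathbf{x}, \mathbf{z}) \nabla_{\phi} \log q_{\phi}(\mathbf{z} | \mathbf{x}) ], \end{equation*}

where the explicit dependence of \(f\) on \(\phi\) contributes nothing in expectation, since \(\mathbb{E}_{q_{\phi}} [ \nabla_{\phi} \log q_{\phi}(\mathbf{z} | \mathbf{x}) ] = \mathbf{0}\). Naive MC estimates of this expectation, however, tend to have high variance. The reparameterization trick instead exploits the structure of \(q_{\phi}\) itself.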

Suppose we can express \(\mathbf{z}\) as a deterministic, differentiable transformation of \(\mathbf{x}\) and an auxiliary noise variable \(\mathbf{\epsilon}\),

\begin{equation*} \mathbf{z} = g_{\phi}(\mathbf{x}, \mathbf{\epsilon}), \quad \mathbf{\epsilon} \sim p(\mathbf{\epsilon}). \end{equation*}

Then we can rewrite the expectation over \(q_{\phi}(\mathbf{z} | \mathbf{x})\) as an expectation over \(p(\mathbf{\epsilon})\) and push the gradient inside,

\begin{align*} \nabla_{\phi} \mathbb{E}_{q_{\phi}(\mathbf{z} | \mathbf{x})} [ f(\mathbf{x}, \mathbf{z}) ] &= \nabla_{\phi} \mathbb{E}_{p(\mathbf{\epsilon})} [ f(\mathbf{x}, g_{\phi}(\mathbf{x}, \mathbf{\epsilon})) ] \\ &= \mathbb{E}_{p(\mathbf{\epsilon})} [ \nabla_{\phi} f(\mathbf{x}, g_{\phi}(\mathbf{x}, \mathbf{\epsilon})) ], \end{align*}

which we can estimate by Monte Carlo. For our diagonal Gaussian approximate posterior, the required transformation is the simple location-scale transformation

\begin{equation*} g_{\phi}(\mathbf{x}, \mathbf{\epsilon}) = \mathbf{\mu}_{\phi}(\mathbf{x}) + \mathbf{\sigma}_{\phi}(\mathbf{x}) \odot \mathbf{\epsilon}, \quad \mathbf{\epsilon} \sim \mathcal{N}(\mathbf{0}, \mathbf{I}). \end{equation*}

Assuming z_mu and z_sigma are the output tensors of the corresponding layers defined above, we can implement this transformation directly with the merge layers Add and Multiply:

eps = Input(shape=(latent_dim,))
z_eps = Multiply()([z_sigma, eps])

z = Add()([z_mu, z_eps])
../../images/vae/reparameterization.svg

Reparameterization with simple location-scale transformation using Keras merge layers.

Contrast this with the original tutorial's use of a Lambda layer, which simultaneously draws samples from a hard-coded base distribution and performs the location-scale transformation. Our implementation achieves a more appropriate level of modularity and abstraction: it makes it clear that each of these atomic building blocks is itself a deterministic transformation, and that composed together they still form a deterministic transformation. The only source of stochasticity comes from an auxiliary input, which we are free to tweak at test time, and which makes it easy to swap in other reparameterizations later on, such as the Gumbel-softmax trick for discrete latent variables [8].
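For comparison, a rough sketch of the Lambda-based approach (the function below is illustrative, not the tutorial's exact code) entangles sampling and reparameterization inside a single layer:

def sampling(args):
    z_mu, z_log_var = args
    # the base distribution is hard-coded inside the layer
    eps = K.random_normal(shape=(K.shape(z_mu)[0], latent_dim))
    return z_mu + K.exp(.5 * z_log_var) * eps

z_lambda = Lambda(sampling)([z_mu, z_log_var])  # not used in our model

In our implementation, by contrast, the noise is declared explicitly as an auxiliary input to the model: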

eps = Input(tensor=K.random_normal(shape=(K.shape(x)[0], latent_dim)))

Note that in the reparameterization diagram above, we fixed \(\mathbf{\sigma}\) and \(\mathbf{\mu}\) as Input layers purely for the sake of illustration, which is why they appear as InputLayer in the figure. In our model, they are the outputs \(\mathbf{\mu}_{\phi}(\mathbf{x})\) and \(\mathbf{\sigma}_{\phi}(\mathbf{x})\) of the inference network defined earlier, giving the encoder architecture below.

../../images/vae/encoder.svg

Encoder architecture.

KL Divergence

The second term of the ELBO, the KL divergence between the approximate posterior and the prior, can be viewed as a regularizer of the latent space: it penalizes approximate posteriors that stray too far from the prior. For our choice of Gaussian approximate posterior and standard Normal prior, it is available in closed form,

\begin{equation*} \mathrm{KL} [q_{\phi}(\mathbf{z} | \mathbf{x}) \| p(\mathbf{z}) ] = - \frac{1}{2} \sum_{k=1}^K \{ 1 + \log \sigma_k^2 - \mu_k^2 - \sigma_k^2 \}, \end{equation*}

where \(K\) is the dimensionality of the latent space. Rather than folding this term into the loss function passed to compile, we implement it as an auxiliary identity layer that adds the KL divergence to the model loss via add_loss:
class KLDivergenceLayer(Layer):

    """ Identity transform layer that adds KL divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):

        mu, log_var = inputs

        kl_batch = - .5 * K.sum(1 + log_var -
                                K.square(mu) -
                                K.exp(log_var), axis=-1)

        self.add_loss(K.mean(kl_batch), inputs=inputs)

        return inputs
z_mu, z_log_var = KLDivergenceLayer()([z_mu, z_log_var])

By itself, this layer would cause the encoder to learn to ignore the input and map both outputs to zero (i.e. \(\mathbf{\mu} = \mathbf{0}\) and \(\mathbf{\sigma} = \mathbf{1}\)), collapsing the approximate posterior onto the prior. It is only when we tack on the decoder that the reconstruction likelihood is introduced; only then do we reconcile the likelihood of the observed data with our prior to form the posterior over latent codes.

At this stage, we could specify prob_encoder = Model(inputs=x, outputs=[z_mu, z_sigma]) and compile it with prob_encoder.compile(optimizer='rmsprop', loss=None). If we were to fit it, it would trivially learn to map all inputs to \(\mathbf{\mu} = \mathbf{0}\) and \(\mathbf{\sigma} = \mathbf{1}\), thereby recovering the prior distribution.

One subtlety: while the inputs mu and log_var each have shape (batch_size, latent_dim), the loss we add via add_loss must be a scalar, hence the K.mean over the batch. This is unlike a loss function specified in compile, which must return a loss vector of shape (batch_size,), since Keras requires the loss of each datapoint in the batch for sample weighting.

../../images/vae/encoder_full.svg

Full encoder architecture, including auxiliary KL divergence layer.

Putting it all together

x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)

z_mu = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

z_mu, z_log_var = KLDivergenceLayer()([z_mu, z_log_var])
z_sigma = Lambda(lambda t: K.exp(.5*t))(z_log_var)

eps = Input(tensor=K.random_normal(shape=(K.shape(x)[0], latent_dim)))
z_eps = Multiply()([z_sigma, eps])
z = Add()([z_mu, z_eps])

decoder = Sequential([
    Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
    Dense(original_dim, activation='sigmoid')
])

x_mean = decoder(z)
vae = Model(inputs=[x, eps], outputs=x_mean)
vae.compile(optimizer='rmsprop', loss=nll)
../../images/vae/vae_full_shapes.svg

Variational autoencoder architecture.
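Here, the nll loss passed to compile is the Bernoulli negative log-likelihood, which supplies the reconstruction term of the ELBO. Its definition (reproduced from the full script in the appendix) is:

def nll(y_true, y_pred):
    """ Bernoulli negative log likelihood. """

    # keras.losses.binary_crossentropy gives the mean
    # over the last axis. We require the sum.
    return K.sum(K.binary_crossentropy(y_true, y_pred), axis=-1)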

When combined end-to-end, the inference network and the deep latent Gaussian model can be seen as having an autoencoder structure. Indeed, this general structure contains the variational autoencoder as a special case and, more traditionally, the Helmholtz machine [4]. Even more generally, we can use this structure to perform amortized variational inference in complex generative models for a wide array of supervised, unsupervised and semi-supervised tasks.

The point of this post is to illustrate a general framework for performing amortized variational inference in Keras, treating the inference network (the approximate posterior) and the generative network (the likelihood) as black boxes. The encoder and decoder we've used, each with a single fully-connected hidden layer, are perhaps the minimal viable architectures. In the examples directory, Keras provides a more sophisticated variational autoencoder with deconvolutional layers; such architecture definitions can be dropped in here without modifying anything else.
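For illustration, a hypothetical drop-in replacement for our Sequential decoder using deconvolutional layers might be sketched as follows (this is our own toy sketch, not the architecture from the Keras examples):

from keras.layers import Conv2DTranspose, Reshape

conv_decoder = Sequential([
    Dense(7 * 7 * 64, input_dim=latent_dim, activation='relu'),
    Reshape((7, 7, 64)),
    Conv2DTranspose(32, 3, strides=2, padding='same', activation='relu'),
    Conv2DTranspose(1, 3, strides=2, padding='same', activation='sigmoid'),
    # flatten back to a vector of pixel probabilities so that the rest
    # of the model (and the nll loss) remains unchanged
    Reshape((original_dim,))
])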

Parameter Learning

The vae is now explicitly specified with the random noise source as an auxiliary input. This allows us to easily control the base distribution \(p(\mathbf{\epsilon})\), and also how we draw Monte Carlo samples of \(\mathbf{z}\) for each datapoint \(\mathbf{x}\); usually, we just stick with a simple isotropic Gaussian and draw a different MC sample for each datapoint (an illustrative variant is sketched after the next snippet). First, we load the training data as usual:

(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, original_dim) / 255.
x_test = x_test.reshape(-1, original_dim) / 255.
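For instance, a purely illustrative variant with a wider base distribution \(p(\mathbf{\epsilon}) = \mathcal{N}(\mathbf{0}, 4 \mathbf{I})\) would only need to swap out the auxiliary noise input:

eps = Input(tensor=K.random_normal(stddev=2.,
                                   shape=(K.shape(x)[0], latent_dim)))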

Compare this to model fitting in the original tutorial's implementation, which feels less intuitive: there, the model is compiled with loss=None explicitly specified (which raises a warning), the targets argument is left unspecified when fit is called, and the reconstruction loss is tacked on through a custom layer. In our implementation, the mapping from the mathematical problem formulation to the code is more natural and straightforward: it's easy to understand at a glance from our call to the fit method that we're training a probabilistic autoencoder to reconstruct x_train:

hist = vae.fit(x_train,
               x_train,
               shuffle=True,
               epochs=epochs,
               batch_size=batch_size,
               validation_data=(x_test, x_test))

Personally, I prefer this view, since all sources of stochasticity emanate from the inputs to the model.

Loss (NELBO) Convergence

import pandas as pd

# plot the training and validation loss (negative ELBO) recorded
# in the History object returned by vae.fit
fig, ax = plt.subplots()

pd.DataFrame(hist.history).plot(ax=ax)

ax.set_ylabel('NELBO')
ax.set_xlabel('# epochs')

plt.show()
../../images/vae/nelbo.svg

Model evaluation
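With the model trained, we can inspect what it has learned. First, we build a standalone encoder that maps each observation to the mean of its approximate posterior, and use it to visualize the test set in the two-dimensional latent space, colored by digit label: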

encoder = Model(x, z_mu)

# display a 2D plot of the digit classes in the latent space
z_test = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(z_test[:, 0], z_test[:, 1], c=y_test,
            alpha=.4, s=3**2, cmap='viridis')
plt.colorbar()
plt.show()
../../images/vae/result_latent_space.png
# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28

# linearly spaced coordinates on the unit square were transformed
# through the inverse CDF (ppf) of the Gaussian to produce values
# of the latent variables z, since the prior of the latent space
# is Gaussian
u_grid = np.dstack(np.meshgrid(np.linspace(0.05, 0.95, n),
                               np.linspace(0.05, 0.95, n)))
z_grid = norm.ppf(u_grid)
x_decoded = decoder.predict(z_grid.reshape(n*n, 2))
x_decoded = x_decoded.reshape(n, n, digit_size, digit_size)

plt.figure(figsize=(10, 10))
plt.imshow(np.block(list(map(list, x_decoded))), cmap='gray')
plt.show()
../../images/vae/result_manifold.png

Recap

  • Demonstration of the Sequential and functional Model APIs
  • A custom auxiliary layer that augments the model loss via add_loss
  • Fixing an auxiliary input as the model's sole source of stochasticity
  • Reparameterization using merge layers (Add and Multiply)

What's next

Normalizing flows for building richer posterior approximations [6].

We will also illustrate how to employ the simple Gumbel-softmax reparameterization [8] to build a categorical VAE with discrete latent variables.

We can easily extend KLDivergenceLayer to use an auxiliary density ratio estimator function, instead of evaluating the KL divergence in the closed-form expression above. This relaxes the requirement on approximate posterior \(q(\mathbf{z}|\mathbf{x})\) (and incidentally, also prior \(p(\mathbf{z})\)) to yield tractable densities, at the cost of maximizing a cruder estimate of the ELBO. This is known as Adversarial Variational Bayes [9], and is an important line of recent research that extends the applicability of variational inference to arbitrarily expressive implicit probabilistic models [10].
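As a taste of what such an extension might look like, here is a hypothetical sketch (names and details are ours, not taken from [9]) of a layer that defers the KL term to an auxiliary density-ratio estimator:

class DensityRatioKLLayer(Layer):

    """ Identity transform layer that adds an *estimated* KL divergence
    to the final model loss, using an auxiliary density-ratio model
    T(x, z) ~= log q(z|x) - log p(z).
    """

    def __init__(self, ratio_estimator, *args, **kwargs):
        self.is_placeholder = True
        self.ratio_estimator = ratio_estimator  # e.g. a Keras Model T(x, z)
        super(DensityRatioKLLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):
        x, z = inputs
        # Monte Carlo estimate of the KL divergence over the batch
        kl_batch = self.ratio_estimator([x, z])
        self.add_loss(K.mean(kl_batch), inputs=inputs)
        return inputs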

Footnotes

[*]

For a complete treatment of variational autoencoders, and variational inference in general, I highly recommend:

References

[1] D. P. Kingma and M. Welling, "Auto-Encoding Variational Bayes," in Proceedings of the 2nd International Conference on Learning Representations (ICLR), 2014.
[2] Edward tutorial on Inference Networks
[3] Section "Recognition models and amortised inference" in Shakir's blog post.
[4] Dayan, P., Hinton, G. E., Neal, R. M., & Zemel, R. S. (1995). The Helmholtz machine. Neural Computation, 7(5), 889–904. http://doi.org/10.1162/neco.1995.7.5.889
[5] D. J. Rezende, S. Mohamed, and D. Wierstra, "Stochastic Backpropagation and Approximate Inference in Deep Generative Models," in Proceedings of the 31st International Conference on Machine Learning, 2014, vol. 32, pp. 1278–1286.
[6] D. Rezende and S. Mohamed, "Variational Inference with Normalizing Flows," in Proceedings of the 32nd International Conference on Machine Learning, 2015, vol. 37, pp. 1530–1538.
[7] Y. Burda, R. Grosse, and R. Salakhutdinov, "Importance Weighted Autoencoders," in Proceedings of the 3rd International Conference on Learning Representations (ICLR), 2015.
[8] E. Jang, S. Gu, and B. Poole, "Categorical Reparameterization with Gumbel-Softmax," in Proceedings of the 5th International Conference on Learning Representations (ICLR), 2017.
[9] L. Mescheder, S. Nowozin, and A. Geiger, "Adversarial Variational Bayes: Unifying Variational Autoencoders and Generative Adversarial Networks," in Proceedings of the 34th International Conference on Machine Learning, 2017, vol. 70, pp. 2391–2400.
[10] D. Tran, R. Ranganath, and D. Blei, "Hierarchical Implicit Models and Likelihood-Free Variational Inference," to appear in Advances in Neural Information Processing Systems, 2017.

Appendix

Below, you can find:

  • The accompanying Jupyter Notebook used to generate the diagrams and plots in this post.
  • The above snippets combined in a single executable Python file:

vae/variational_autoencoder_improved.py (Source)

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras import backend as K

from keras.layers import Input, Dense, Lambda, Layer, Add, Multiply
from keras.models import Model, Sequential
from keras.datasets import mnist


batch_size = 100
original_dim = 784
latent_dim = 2
intermediate_dim = 256
epochs = 50
epsilon_std = 1.0


def nll(y_true, y_pred):
    """ Bernoulli negative log likelihood. """

    # keras.losses.binary_crossentropy gives the mean
    # over the last axis. We require the sum.
    return K.sum(K.binary_crossentropy(y_true, y_pred), axis=-1)


class KLDivergenceLayer(Layer):

    """ Identity transform layer that adds KL divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):

        mu, log_var = inputs

        kl_batch = - .5 * K.sum(1 + log_var -
                                K.square(mu) -
                                K.exp(log_var), axis=-1)

        self.add_loss(K.mean(kl_batch), inputs=inputs)

        return inputs

x = Input(shape=(original_dim,))
h = Dense(intermediate_dim, activation='relu')(x)

z_mu = Dense(latent_dim)(h)
z_log_var = Dense(latent_dim)(h)

z_mu, z_log_var = KLDivergenceLayer()([z_mu, z_log_var])
z_sigma = Lambda(lambda t: K.exp(.5*t))(z_log_var)

eps = Input(tensor=K.random_normal(stddev=epsilon_std,
                                   shape=(K.shape(x)[0], latent_dim)))
z_eps = Multiply()([z_sigma, eps])
z = Add()([z_mu, z_eps])

decoder = Sequential([
    Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
    Dense(original_dim, activation='sigmoid')
])

x_mean = decoder(z)

vae = Model(inputs=[x, eps], outputs=x_mean)
vae.compile(optimizer='rmsprop', loss=nll)

# train the VAE on MNIST digits
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, original_dim) / 255.
x_test = x_test.reshape(-1, original_dim) / 255.

vae.fit(x_train,
        x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(x_test, x_test))

encoder = Model(x, z_mu)

# display a 2D plot of the digit classes in the latent space
z_test = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(z_test[:, 0], z_test[:, 1], c=y_test,
            alpha=.4, s=3**2, cmap='viridis')
plt.colorbar()
plt.show()

# display a 2D manifold of the digits
n = 15  # figure with 15x15 digits
digit_size = 28

# linearly spaced coordinates on the unit square were transformed
# through the inverse CDF (ppf) of the Gaussian to produce values
# of the latent variables z, since the prior of the latent space
# is Gaussian
u_grid = np.dstack(np.meshgrid(np.linspace(0.05, 0.95, n),
                               np.linspace(0.05, 0.95, n)))
z_grid = norm.ppf(u_grid)
x_decoded = decoder.predict(z_grid.reshape(n*n, 2))
x_decoded = x_decoded.reshape(n, n, digit_size, digit_size)

plt.figure(figsize=(10, 10))
plt.imshow(np.block(list(map(list, x_decoded))), cmap='gray')
plt.show()
