variational_autoencoder-checkpoint.ipynb (Source)

Preamble

In [1]:
%matplotlib notebook
In [2]:
import numpy as np
from scipy.stats import norm

import matplotlib.pyplot as plt

from keras import backend as K
from keras.models import Model, Sequential
from keras.layers import (Dense, Input, Layer, Lambda, 
                          Add, Multiply)

from keras.datasets import mnist
from keras.utils.vis_utils import model_to_dot, plot_model

from IPython.display import SVG
Using TensorFlow backend.

Notebook Configuration

In [3]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [4]:
'TensorFlow version: ' + K.tf.__version__
Out[4]:
'TensorFlow version: 1.3.0'
Constant definitions
In [5]:
mc_samples = 5
batch_size = 100
original_dim = 784
latent_dim = 2
intermediate_dim = 256
epochs = 50
epsilon_std = 1.0

Variational Autoencoder

Model definition

Encoder

Reparameterization

Noise as an auxiliary input to the network
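As a brief restatement of what the cells below compute: the reparameterization trick writes the latent sample as a deterministic function of the encoder outputs and external standard-normal noise, so gradients can flow through the sampling step.

z = \mu(x) + \sigma(x) \odot \epsilon, \qquad \epsilon \sim \mathcal{N}(0, I), \qquad \sigma(x) = \exp\!\left(\tfrac{1}{2}\log\sigma^2(x)\right)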

In [6]:
x = Input(shape=(original_dim,), name='x')
In [7]:
h = Dense(intermediate_dim, activation='relu', name='hidden')(x)
In [8]:
z_mu = Dense(latent_dim, name='mu')(h)
z_log_var = Dense(latent_dim, name='log_var')(h)
In [9]:
class KLDivergenceLayer(Layer):

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):

        mu, log_var = inputs

        kl = - .5 * K.sum(1 + log_var
                            - K.square(mu)
                            - K.exp(log_var), axis=-1)

        self.add_loss(kl, inputs=inputs)

        return inputs
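For reference, the kl expression in the layer above is the closed-form KL divergence between the diagonal-Gaussian approximate posterior q(z|x) = N(\mu, \mathrm{diag}(\sigma^2)) and the standard-normal prior p(z) = N(0, I), with d = latent_dim:

\mathrm{KL}\big(q(z \mid x)\,\|\,p(z)\big) = -\frac{1}{2}\sum_{j=1}^{d}\left(1 + \log\sigma_j^2 - \mu_j^2 - \sigma_j^2\right)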
In [10]:
z_mu, z_log_var = KLDivergenceLayer(name='kl')([z_mu, z_log_var])
In [11]:
sigma = Lambda(lambda x: K.exp(.5*x), name='sigma')(z_log_var)
In [49]:
# z_mean = Input(shape=(latent_dim,), name='mu')
# z_std_dev = Input(shape=(latent_dim,), name='sigma')

# eps = Input(shape=(mc_samples, latent_dim), name='eps')
In [36]:
# z_eps = Multiply(name='z_eps')([z_std_dev, eps])
# z = Add(name='z')([z_mean, z_eps])
In [37]:
# m = Model(inputs=[eps, z_mean, z_std_dev], outputs=z)
In [38]:
# SVG(model_to_dot(m, show_shapes=False).create(prog='dot', format='svg'))
Out[38]:
[Model diagram: InputLayers mu, sigma and eps; sigma and eps feed z_eps (Multiply), which together with mu feeds z (Add)]
In [41]:
# plot_model(
#     model=m, show_shapes=False,
#     to_file='../images/vae/reparameterization.svg'
# )
In [42]:
# plot_model(
#     model=m, show_shapes=True,
#     to_file='../images/vae/reparameterization_shapes.svg'
# )
In [13]:
# eps = Input(shape=(n_samples, latent_dim,), name='epsilon')
eps = Input(shape=(mc_samples, latent_dim), name='epsilon')
sigma_eps = Multiply(name='sigma_eps')([sigma, eps])
z = Add(name='z')([z_mu, sigma_eps])
In [14]:
encoder = Model(inputs=[x, eps], outputs=z)
In [15]:
SVG(model_to_dot(encoder, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[15]:
[Encoder diagram: x: InputLayer (None, 784) → hidden: Dense (None, 256) → mu: Dense (None, 2) and log_var: Dense (None, 2) → kl: KLDivergenceLayer → sigma: Lambda (None, 2); sigma and epsilon: InputLayer (None, 5, 2) → sigma_eps: Multiply (None, 5, 2); mu and sigma_eps → z: Add (None, 5, 2)]

Decoder

In [16]:
# decoder = Sequential([
#     Dense(intermediate_dim, activation='relu', input_dim=latent_dim),
#     Dense(original_dim, activation='sigmoid')
# ], name='decoder')
In [17]:
# x_decoded_mean = decoder(z)
In [18]:
# SVG(model_to_dot(decoder, show_shapes=True)
#     .create(prog='dot', format='svg'))
In [26]:
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
In [27]:
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)
In [39]:
def nll(y_true, y_pred):
    """ Negative log likelihood. """

    # keras.losses.binary_crossentropy gives the mean
    # over the last axis; we require the sum over pixels
    return K.sum(K.binary_crossentropy(y_true, y_pred), axis=-1)
In [40]:
vae = Model(inputs=[x, eps], outputs=x_decoded_mean)
vae.compile(optimizer='rmsprop', loss=nll)
In [41]:
SVG(model_to_dot(vae, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[41]:
[VAE diagram: x: InputLayer (None, 784) → hidden: Dense (None, 256) → mu: Dense (None, 2) and log_var: Dense (None, 2) → kl: KLDivergenceLayer → sigma: Lambda (None, 2); sigma and epsilon: InputLayer (None, 5, 2) → sigma_eps: Multiply (None, 5, 2); mu and sigma_eps → z: Add (None, 5, 2) → dense_3: Dense (None, 5, 256) → dense_4: Dense (None, 5, 784)]
In [42]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(-1, 28*28) / 255.
X_test = X_test.reshape(-1, 28*28) / 255.
In [43]:
vae.evaluate(
    [X_train, np.random.randn(len(X_train), mc_samples, latent_dim)],
    X_train,
    batch_size=batch_size
)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
<ipython-input-43-13501a5198e8> in <module>()
      2     [X_train, np.random.randn(len(X_train), mc_samples, latent_dim)],
      3     X_train,
----> 4     batch_size=batch_size
      5 )

~/.virtualenvs/anmoku/lib/python3.6/site-packages/keras/engine/training.py in evaluate(self, x, y, batch_size, verbose, sample_weight, steps)
   1655                                batch_size=batch_size,
   1656                                verbose=verbose,
-> 1657                                steps=steps)
   1658 
   1659     def predict(self, x,

~/.virtualenvs/anmoku/lib/python3.6/site-packages/keras/engine/training.py in _test_loop(self, f, ins, batch_size, verbose, steps)
   1337                     ins_batch = _slice_arrays(ins, batch_ids)
   1338 
-> 1339                 batch_outs = f(ins_batch)
   1340                 if isinstance(batch_outs, list):
   1341                     if batch_index == 0:

~/.virtualenvs/anmoku/lib/python3.6/site-packages/keras/backend/tensorflow_backend.py in __call__(self, inputs)
   2271         updated = session.run(self.outputs + [self.updates_op],
   2272                               feed_dict=feed_dict,
-> 2273                               **self.session_kwargs)
   2274         return updated[:len(self.outputs)]
   2275 

~/.virtualenvs/anmoku/lib/python3.6/site-packages/tensorflow/python/client/session.py in run(self, fetches, feed_dict, options, run_metadata)
    893     try:
    894       result = self._run(None, fetches, feed_dict, options_ptr,
--> 895                          run_metadata_ptr)
    896       if run_metadata:
    897         proto_data = tf_session.TF_GetBuffer(run_metadata_ptr)

~/.virtualenvs/anmoku/lib/python3.6/site-packages/tensorflow/python/client/session.py in _run(self, handle, fetches, feed_dict, options, run_metadata)
   1098                 'Cannot feed value of shape %r for Tensor %r, '
   1099                 'which has shape %r'
-> 1100                 % (np_val.shape, subfeed_t.name, str(subfeed_t.get_shape())))
   1101           if not self.graph.is_feedable(subfeed_t):
   1102             raise ValueError('Tensor %s may not be fed.' % subfeed_t)

ValueError: Cannot feed value of shape (100, 784) for Tensor 'dense_4_target_2:0', which has shape '(?, ?, ?)'
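The shape error above is expected with this version of the model: because eps carries an extra Monte Carlo axis, the decoder output has shape (batch_size, mc_samples, original_dim), so the flat (batch_size, original_dim) targets cannot be fed to the 3-D target placeholder. A minimal sketch of one possible workaround (not in the original notebook; X_train_mc is a name introduced here) is to repeat the targets along the sample axis so the shapes match:

# Repeat each flattened image mc_samples times along a new axis so the
# targets have shape (num_examples, mc_samples, original_dim), matching
# the decoder output.
X_train_mc = np.repeat(X_train[:, np.newaxis, :], mc_samples, axis=1)
vae.evaluate(
    [X_train, np.random.randn(len(X_train), mc_samples, latent_dim)],
    X_train_mc,
    batch_size=batch_size
)

The fit call in the next cell, by contrast, feeds noise of shape (len(X_train), latent_dim), which corresponds to a single-sample variant of the eps input rather than the (mc_samples, latent_dim) input defined above.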
In [24]:
vae.fit([X_train, np.random.randn(len(X_train), latent_dim)],
        X_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(
            [X_test, np.random.randn(len(X_test), latent_dim)], 
            X_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 2s - loss: 192.8234 - val_loss: 173.4639
Epoch 2/50
60000/60000 [==============================] - 2s - loss: 170.4935 - val_loss: 167.3251
Epoch 3/50
60000/60000 [==============================] - 2s - loss: 165.8410 - val_loss: 163.8605
Epoch 4/50
60000/60000 [==============================] - 2s - loss: 162.9716 - val_loss: 161.9201
Epoch 5/50
60000/60000 [==============================] - 2s - loss: 160.9723 - val_loss: 160.2319
Epoch 6/50
60000/60000 [==============================] - 2s - loss: 159.4687 - val_loss: 159.2952
Epoch 7/50
60000/60000 [==============================] - 2s - loss: 158.3019 - val_loss: 158.0032
Epoch 8/50
60000/60000 [==============================] - 2s - loss: 157.3528 - val_loss: 157.4063
Epoch 9/50
60000/60000 [==============================] - 2s - loss: 156.5618 - val_loss: 156.8355
Epoch 10/50
60000/60000 [==============================] - 2s - loss: 155.8394 - val_loss: 156.1229
Epoch 11/50
60000/60000 [==============================] - 2s - loss: 155.2358 - val_loss: 155.4725
Epoch 12/50
60000/60000 [==============================] - 2s - loss: 154.6807 - val_loss: 155.2970
Epoch 13/50
60000/60000 [==============================] - 2s - loss: 154.1845 - val_loss: 154.6605
Epoch 14/50
60000/60000 [==============================] - 2s - loss: 153.7407 - val_loss: 154.5573
Epoch 15/50
60000/60000 [==============================] - 2s - loss: 153.3144 - val_loss: 154.0561
Epoch 16/50
60000/60000 [==============================] - 2s - loss: 152.9552 - val_loss: 153.8598
Epoch 17/50
60000/60000 [==============================] - 2s - loss: 152.6090 - val_loss: 153.7119
Epoch 18/50
60000/60000 [==============================] - 2s - loss: 152.2861 - val_loss: 153.4686
Epoch 19/50
60000/60000 [==============================] - 2s - loss: 151.9980 - val_loss: 153.2624
Epoch 20/50
60000/60000 [==============================] - 2s - loss: 151.7103 - val_loss: 153.2426
Epoch 21/50
60000/60000 [==============================] - 2s - loss: 151.4483 - val_loss: 152.9442
Epoch 22/50
60000/60000 [==============================] - 2s - loss: 151.1993 - val_loss: 152.6045
Epoch 23/50
60000/60000 [==============================] - 2s - loss: 150.9613 - val_loss: 152.6727
Epoch 24/50
60000/60000 [==============================] - 2s - loss: 150.7605 - val_loss: 152.4743
Epoch 25/50
60000/60000 [==============================] - 2s - loss: 150.5422 - val_loss: 152.0614
Epoch 26/50
60000/60000 [==============================] - 2s - loss: 150.3498 - val_loss: 152.2016
Epoch 27/50
60000/60000 [==============================] - 2s - loss: 150.1937 - val_loss: 152.2652
Epoch 28/50
60000/60000 [==============================] - 2s - loss: 149.9897 - val_loss: 152.0452
Epoch 29/50
60000/60000 [==============================] - 2s - loss: 149.8584 - val_loss: 152.0970
Epoch 30/50
60000/60000 [==============================] - 2s - loss: 149.7068 - val_loss: 151.8346
Epoch 31/50
60000/60000 [==============================] - 2s - loss: 149.5380 - val_loss: 152.0878
Epoch 32/50
60000/60000 [==============================] - 2s - loss: 149.4005 - val_loss: 151.8262
Epoch 33/50
60000/60000 [==============================] - 2s - loss: 149.2798 - val_loss: 152.0060
Epoch 34/50
60000/60000 [==============================] - 2s - loss: 149.1350 - val_loss: 152.0361
Epoch 35/50
60000/60000 [==============================] - 2s - loss: 148.9873 - val_loss: 151.8470
Epoch 36/50
60000/60000 [==============================] - 2s - loss: 148.8849 - val_loss: 151.7099
Epoch 37/50
60000/60000 [==============================] - 2s - loss: 148.7897 - val_loss: 151.5309
Epoch 38/50
60000/60000 [==============================] - 2s - loss: 148.6630 - val_loss: 151.6342
Epoch 39/50
60000/60000 [==============================] - 2s - loss: 148.5674 - val_loss: 151.7563
Epoch 40/50
60000/60000 [==============================] - 2s - loss: 148.4517 - val_loss: 151.5432
Epoch 41/50
60000/60000 [==============================] - 2s - loss: 148.3366 - val_loss: 151.4317
Epoch 42/50
60000/60000 [==============================] - 2s - loss: 148.2159 - val_loss: 151.4681
Epoch 43/50
60000/60000 [==============================] - 2s - loss: 148.1392 - val_loss: 151.6582
Epoch 44/50
60000/60000 [==============================] - 2s - loss: 148.0562 - val_loss: 151.6221
Epoch 45/50
60000/60000 [==============================] - 2s - loss: 147.9545 - val_loss: 151.3868
Epoch 46/50
60000/60000 [==============================] - 2s - loss: 147.8461 - val_loss: 151.4615
Epoch 47/50
60000/60000 [==============================] - 2s - loss: 147.7605 - val_loss: 151.7655
Epoch 48/50
60000/60000 [==============================] - 2s - loss: 147.6813 - val_loss: 151.1806
Epoch 49/50
60000/60000 [==============================] - 2s - loss: 147.6052 - val_loss: 151.4509
Epoch 50/50
60000/60000 [==============================] - 2s - loss: 147.5226 - val_loss: 151.4373
Out[24]:
<keras.callbacks.History at 0x7fb4912a1fd0>
In [38]:
X_test_encoded = encoder.predict([X_test, np.random.randn(len(X_test), latent_dim)])
In [91]:
fig, ax = plt.subplots(figsize=(6, 5))

cbar = ax.scatter(X_test_encoded[:, 0], X_test_encoded[:, 1], 
                  c=y_test, alpha=.4, s=3**2,
                  cmap='viridis')

fig.colorbar(cbar, ax=ax)

plt.show()
In [40]:
n = 15  # figure with 15x15 digits
digit_size = 28
im = np.zeros((digit_size * n, digit_size * n))
In [53]:
# Linearly spaced coordinates on the unit square are transformed
# through the inverse CDF (ppf) of the Gaussian to produce values
# of the latent variables z, since the prior of the latent space
# is Gaussian.
u1, u2 = np.meshgrid(np.linspace(0.05, 0.95, n), 
                     np.linspace(0.05, 0.95, n))
u_grid = np.dstack((u1, u2))
z_grid = norm.ppf(u_grid)
In [55]:
z_grid.shape
Out[55]:
(15, 15, 2)
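The next cell calls decoder.predict, but the standalone decoder model was only defined in the commented-out Sequential cell above. A minimal sketch that rebuilds it from the trained decoder_h and decoder_mean layers (the names decoder_input, _h_decoded and _x_decoded are introduced here):

# Wire the trained decoder layers into a standalone model that maps a
# latent sample directly to a reconstructed image.
decoder_input = Input(shape=(latent_dim,), name='z_sample')
_h_decoded = decoder_h(decoder_input)
_x_decoded = decoder_mean(_h_decoded)
decoder = Model(inputs=decoder_input, outputs=_x_decoded, name='decoder')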
In [63]:
x_decoded = decoder.predict(z_grid.reshape(n*n, 2))
x_decoded.shape
Out[63]:
(225, 784)
In [92]:
x_decoded = x_decoded.reshape(n, n, digit_size, digit_size)
x_decoded.shape
Out[92]:
(15, 15, 28, 28)
In [93]:
for i in range(n):
    for j in range(n):
        im[i * digit_size: (i + 1) * digit_size,
           j * digit_size: (j + 1) * digit_size] = x_decoded[i, j]
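The nested loop above can also be expressed as a single transpose and reshape; a sketch of the equivalent vectorized form, assuming x_decoded has shape (n, n, digit_size, digit_size) as shown above:

# Merge the (grid row, pixel row) and (grid column, pixel column) axes
# to tile the n x n decoded digits into one large image.
im = x_decoded.transpose(0, 2, 1, 3).reshape(n * digit_size, n * digit_size)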
In [94]:
fig, ax = plt.subplots(figsize=(7, 7))

# ax.imshow(np.reshape(x_decoded, (28*15, 28*15), order='A'), 
#           cmap='gray')

ax.imshow(im, cmap='gray')

plt.show()