variational_autoencoder-checkpoint.ipynb (Source)

Preamble

In [1]:
%matplotlib notebook
In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras.layers import (Input, InputLayer, Dense, Lambda,
                          Layer, Add, Multiply)
from keras.models import Model, Sequential
from keras import backend as K
from keras.datasets import mnist

from keras.utils.vis_utils import model_to_dot, plot_model

from IPython.display import SVG
Using TensorFlow backend.

Notebook Configuration

In [3]:
# compact numpy printing for the inline diagnostics below
np.set_printoptions(
    precision=2, edgeitems=3, linewidth=80, suppress=True,
)
In [4]:
'TensorFlow version: ' + K.tf.__version__
Out[4]:
'TensorFlow version: 1.3.0'
Constant definitions
In [5]:
mc_samples = 5         # Monte Carlo noise samples per datapoint
original_dim = 784     # flattened 28x28 MNIST image
latent_dim = 2         # dimensionality of latent variable z
intermediate_dim = 256  # hidden-layer width for encoder/decoder
batch_size = 100       # was undefined but used by fit/evaluate below
epochs = 50
epsilon_std = 1.0      # std dev of the reparameterization noise

Variational Autoencoder

Model definition

Encoder

Illustration 1: Reparameterization using Keras Layers
In [6]:
# stand-alone inputs for mu and sigma, used only to illustrate
# the reparameterization trick as a graph of Keras layers
z_mu = Input(shape=(latent_dim,), name='mu')
z_sigma = Input(shape=(latent_dim,), name='sigma')
In [7]:
# reparameterization trick: z = mu + sigma * eps, with eps ~ N(0, 1)
# supplied as an external input so the sampling is differentiable
eps = Input(shape=(latent_dim,), name='eps')
z_eps = Multiply(name='z_eps')([z_sigma, eps])
z = Add(name='z')([z_mu, z_eps])
In [8]:
m = Model(inputs=[eps, z_mu, z_sigma], outputs=z)
In [9]:
SVG(model_to_dot(m, show_shapes=False).create(prog='dot', format='svg'))
Out[9]:
G 140434992610384 sigma: InputLayer140434992722776 z_eps: Multiply140434992610384->140434992722776 140434992722832 eps: InputLayer140434992722832->140434992722776 140434992610440 mu: InputLayer140434992723392 z: Add140434992610440->140434992723392 140434992722776->140434992723392
In [10]:
# export the reparameterization diagram to SVG
plot_model(
    model=m, show_shapes=False,
    to_file='../../images/vae/reparameterization.svg'
)
In [11]:
# shape-annotated variant of the diagram. The original cell was an
# exact duplicate of the previous one (same file written twice);
# every other figure pair in this notebook follows the
# show_shapes=False / show_shapes=True -> *_shapes.svg convention.
plot_model(
    model=m, show_shapes=True,
    to_file='../../images/vae/reparameterization_shapes.svg'
)
Illustration 2: Encoder architecture
In [12]:
# encoder: x -> hidden -> (mu, log_var); sigma = exp(log_var / 2)
# (predicting log-variance keeps sigma positive without constraints)
x = Input(shape=(original_dim,), name='x')
h = Dense(intermediate_dim, activation='relu', name='encoder_hidden')(x)
z_mu = Dense(latent_dim, name='mu')(h)
z_log_var = Dense(latent_dim, name='log_var')(h)
z_sigma = Lambda(lambda t: K.exp(.5*t), name='sigma')(z_log_var)
In [13]:
# reparameterize with externally supplied noise eps ~ N(0, 1)
eps = Input(shape=(latent_dim,), name='epsilon')
z_eps = Multiply(name='z_eps')([z_sigma, eps])
z = Add(name='z')([z_mu, z_eps])
In [14]:
encoder = Model(inputs=[x, eps], outputs=z)
In [15]:
SVG(model_to_dot(encoder, show_shapes=False)
    .create(prog='dot', format='svg'))
Out[15]:
G 140434991284912 x: InputLayer140434991287264 encoder_hidden: Dense140434991284912->140434991287264 140434991287600 log_var: Dense140434991287264->140434991287600 140434991288272 mu: Dense140434991287264->140434991288272 140434991285472 sigma: Lambda140434991287600->140434991285472 140434991784008 z_eps: Multiply140434991285472->140434991784008 140434991786248 epsilon: InputLayer140434991786248->140434991784008 140434991785296 z: Add140434991288272->140434991785296 140434991784008->140434991785296
In [16]:
# export encoder diagrams (plain and shape-annotated)
plot_model(
    model=encoder, show_shapes=False,
    to_file='../../images/vae/encoder.svg'
)
In [17]:
plot_model(
    model=encoder, show_shapes=True,
    to_file='../../images/vae/encoder_shapes.svg'
)
Illustration 3: Full Encoder architecture with auxiliary layers
In [18]:
class KLDivergenceLayer(Layer):
    """Identity transform layer that adds the KL divergence
    KL(q(z|x) || p(z)) to the model loss via `add_loss`.

    Inputs are passed through unchanged so the layer can be spliced
    into an existing graph between (mu, log_var) and their consumers.
    """

    def __init__(self, *args, **kwargs):
        # mark as a placeholder-style layer: it contributes a loss
        # but no trainable transformation
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):
        mu, log_var = inputs

        # closed-form KL between the diagonal Gaussian q(z|x) and the
        # standard normal prior, summed over latent dimensions.
        # (removed a leftover debug K.print_tensor wrapper that logged
        # log_var on every forward pass)
        kl = - .5 * K.sum(1 + log_var
                            - K.square(mu)
                            - K.exp(log_var), axis=-1)

        self.add_loss(kl, inputs=inputs)

        return inputs
In [19]:
# route (mu, log_var) through the KL layer so the divergence is
# registered as an auxiliary loss; the tensors pass through unchanged
z_mu, z_log_var = KLDivergenceLayer(name='kl')([z_mu, z_log_var])
z_sigma = Lambda(lambda t: K.exp(.5*t), name='sigma')(z_log_var)
In [20]:
# reparameterize again for the full encoder graph
eps = Input(shape=(latent_dim,), name='epsilon')
z_eps = Multiply(name='sigma_eps')([z_sigma, eps])
z = Add(name='z')([z_mu, z_eps])
In [21]:
encoder = Model(inputs=[x, eps], outputs=z)
In [22]:
SVG(model_to_dot(encoder, show_shapes=False)
    .create(prog='dot', format='svg'))
Out[22]:
G 140434991284912 x: InputLayer140434991287264 encoder_hidden: Dense140434991284912->140434991287264 140434991288272 mu: Dense140434991287264->140434991288272 140434991287600 log_var: Dense140434991287264->140434991287600 140434992459672 kl: KLDivergenceLayer140434991288272->140434992459672 140434991287600->140434992459672 140434992808568 sigma: Lambda140434992459672->140434992808568 140434992456368 z: Add140434992459672->140434992456368 140434992457992 sigma_eps: Multiply140434992808568->140434992457992 140434992456424 epsilon: InputLayer140434992456424->140434992457992 140434992457992->140434992456368
In [23]:
# export full-encoder diagrams (plain and shape-annotated)
plot_model(
    model=encoder, show_shapes=False,
    to_file='../../images/vae/encoder_full.svg'
)
In [24]:
plot_model(
    model=encoder, show_shapes=True,
    to_file='../../images/vae/encoder_full_shapes.svg'
)

Decoder

In [25]:
# decoder: z -> hidden -> Bernoulli means over the 784 pixels
decoder = Sequential([
    Dense(intermediate_dim, input_dim=latent_dim, 
          activation='relu', name='decoder_hidden'),
    Dense(original_dim, activation='sigmoid', name='x_mean')
], name='decoder')
In [26]:
# equivalent to above. Writing InputLayer 
# explicitly for diagram 
# NOTE(review): InputLayer must be imported from keras.layers for this
# cell to run on a fresh kernel — the import cell only brings in Input
decoder = Sequential([
    InputLayer(input_shape=(latent_dim,), name='z'),
    Dense(intermediate_dim, activation='relu', name='decoder_hidden'),
    Dense(original_dim, activation='sigmoid', name='x_mean')
], name='decoder')
In [27]:
x_decoded = decoder(z)
In [28]:
SVG(model_to_dot(decoder, show_shapes=False)
    .create(prog='dot', format='svg'))
Out[28]:
G 140434990562160 z: InputLayer140434991557656 decoder_hidden: Dense140434990562160->140434991557656 140434991649456 x_mean: Dense140434991557656->140434991649456
In [29]:
# export decoder diagrams (plain and shape-annotated)
plot_model(
    model=decoder, show_shapes=False,
    to_file='../../images/vae/decoder.svg'
)
In [30]:
plot_model(
    model=decoder, show_shapes=True,
    to_file='../../images/vae/decoder_shapes.svg'
)
In [31]:
# again, equivalent to above. writing out fully
# for final end-to-end vae architecture visualization;
# otherwise, sequential models just get chunked into
# single layer
h_decoded = Dense(intermediate_dim, 
                  activation='relu', 
                  name='decoder_hidden')(z)
x_decoded = Dense(original_dim, 
                  activation='sigmoid', 
                  name='x_mean')(h_decoded)
In [32]:
def nll(y_true, y_pred):
    """Negative log likelihood (Bernoulli reconstruction loss).

    keras.losses.binary_crossentropy averages over the last axis,
    whereas the VAE objective needs the per-pixel sum.
    """
    per_pixel_ce = K.binary_crossentropy(y_true, y_pred)
    return K.sum(per_pixel_ce, axis=-1)
In [33]:
# end-to-end VAE: nll supplies the reconstruction term; the KL term
# was already attached to the graph by KLDivergenceLayer.add_loss
vae = Model(inputs=[x, eps], outputs=x_decoded)
vae.compile(optimizer='rmsprop', loss=nll)
In [34]:
SVG(model_to_dot(vae, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[34]:
G 140434991284912 x: InputLayerinput:output:(None, 784)(None, 784)140434991287264 encoder_hidden: Denseinput:output:(None, 784)(None, 256)140434991284912->140434991287264 140434991288272 mu: Denseinput:output:(None, 256)(None, 2)140434991287264->140434991288272 140434991287600 log_var: Denseinput:output:(None, 256)(None, 2)140434991287264->140434991287600 140434992459672 kl: KLDivergenceLayerinput:output:[(None, 2), (None, 2)][(None, 2), (None, 2)]140434991288272->140434992459672 140434991287600->140434992459672 140434992808568 sigma: Lambdainput:output:(None, 2)(None, 2)140434992459672->140434992808568 140434992456368 z: Addinput:output:[(None, 2), (None, 2)](None, 2)140434992459672->140434992456368 140434992457992 sigma_eps: Multiplyinput:output:[(None, 2), (None, 2)](None, 2)140434992808568->140434992457992 140434992456424 epsilon: InputLayerinput:output:(None, 2)(None, 2)140434992456424->140434992457992 140434992457992->140434992456368 140434989786952 decoder_hidden: Denseinput:output:(None, 2)(None, 256)140434992456368->140434989786952 140434989788016 x_mean: Denseinput:output:(None, 256)(None, 784)140434989786952->140434989788016
In [35]:
# export full-VAE diagrams (plain and shape-annotated)
plot_model(
    model=vae, show_shapes=False,
    to_file='../../images/vae/vae_full.svg'
)
In [36]:
plot_model(
    model=vae, show_shapes=True,
    to_file='../../images/vae/vae_full_shapes.svg'
)
In [42]:
# load MNIST, flatten 28x28 images to 784-vectors, scale pixels to [0, 1]
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(-1, 28*28) / 255.
X_test = X_test.reshape(-1, 28*28) / 255.
In [44]:
X_train.shape
Out[44]:
(60000, 784)
In [46]:
?np.repeat
In [47]:
# replicate each training example along a new axis to pair it with
# mc_samples noise draws: (60000, 784) -> (60000, 5, 784).
# The original `np.repeat(X_train, repeats=(1, 5, 1))` raised a
# ValueError: a tuple of repeats cannot broadcast against a flattened
# array; insert an axis first and repeat along it instead.
np.repeat(X_train[:, np.newaxis], mc_samples, axis=1).shape
In [43]:
# evaluate the loss on the training set. The eps input must match the
# model's (None, latent_dim) placeholder; the original cell fed a
# (n, mc_samples, latent_dim) tensor and raised a shape ValueError.
vae.evaluate(
    [X_train, np.random.randn(len(X_train), latent_dim)],
    X_train,
    batch_size=batch_size
)
In [24]:
# train end-to-end; a fresh eps ~ N(0, 1) draw is fed as the second input
# NOTE(review): assumes `batch_size` is defined earlier — it does not
# appear in the constants cell, so confirm it exists on a fresh kernel
vae.fit([X_train, np.random.randn(len(X_train), latent_dim)],
        X_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_data=(
            [X_test, np.random.randn(len(X_test), latent_dim)], 
            X_test))
Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 2s - loss: 192.8234 - val_loss: 173.4639
Epoch 2/50
60000/60000 [==============================] - 2s - loss: 170.4935 - val_loss: 167.3251
Epoch 3/50
60000/60000 [==============================] - 2s - loss: 165.8410 - val_loss: 163.8605
Epoch 4/50
60000/60000 [==============================] - 2s - loss: 162.9716 - val_loss: 161.9201
Epoch 5/50
60000/60000 [==============================] - 2s - loss: 160.9723 - val_loss: 160.2319
Epoch 6/50
60000/60000 [==============================] - 2s - loss: 159.4687 - val_loss: 159.2952
Epoch 7/50
60000/60000 [==============================] - 2s - loss: 158.3019 - val_loss: 158.0032
Epoch 8/50
60000/60000 [==============================] - 2s - loss: 157.3528 - val_loss: 157.4063
Epoch 9/50
60000/60000 [==============================] - 2s - loss: 156.5618 - val_loss: 156.8355
Epoch 10/50
60000/60000 [==============================] - 2s - loss: 155.8394 - val_loss: 156.1229
Epoch 11/50
60000/60000 [==============================] - 2s - loss: 155.2358 - val_loss: 155.4725
Epoch 12/50
60000/60000 [==============================] - 2s - loss: 154.6807 - val_loss: 155.2970
Epoch 13/50
60000/60000 [==============================] - 2s - loss: 154.1845 - val_loss: 154.6605
Epoch 14/50
60000/60000 [==============================] - 2s - loss: 153.7407 - val_loss: 154.5573
Epoch 15/50
60000/60000 [==============================] - 2s - loss: 153.3144 - val_loss: 154.0561
Epoch 16/50
60000/60000 [==============================] - 2s - loss: 152.9552 - val_loss: 153.8598
Epoch 17/50
60000/60000 [==============================] - 2s - loss: 152.6090 - val_loss: 153.7119
Epoch 18/50
60000/60000 [==============================] - 2s - loss: 152.2861 - val_loss: 153.4686
Epoch 19/50
60000/60000 [==============================] - 2s - loss: 151.9980 - val_loss: 153.2624
Epoch 20/50
60000/60000 [==============================] - 2s - loss: 151.7103 - val_loss: 153.2426
Epoch 21/50
60000/60000 [==============================] - 2s - loss: 151.4483 - val_loss: 152.9442
Epoch 22/50
60000/60000 [==============================] - 2s - loss: 151.1993 - val_loss: 152.6045
Epoch 23/50
60000/60000 [==============================] - 2s - loss: 150.9613 - val_loss: 152.6727
Epoch 24/50
60000/60000 [==============================] - 2s - loss: 150.7605 - val_loss: 152.4743
Epoch 25/50
60000/60000 [==============================] - 2s - loss: 150.5422 - val_loss: 152.0614
Epoch 26/50
60000/60000 [==============================] - 2s - loss: 150.3498 - val_loss: 152.2016
Epoch 27/50
60000/60000 [==============================] - 2s - loss: 150.1937 - val_loss: 152.2652
Epoch 28/50
60000/60000 [==============================] - 2s - loss: 149.9897 - val_loss: 152.0452
Epoch 29/50
60000/60000 [==============================] - 2s - loss: 149.8584 - val_loss: 152.0970
Epoch 30/50
60000/60000 [==============================] - 2s - loss: 149.7068 - val_loss: 151.8346
Epoch 31/50
60000/60000 [==============================] - 2s - loss: 149.5380 - val_loss: 152.0878
Epoch 32/50
60000/60000 [==============================] - 2s - loss: 149.4005 - val_loss: 151.8262
Epoch 33/50
60000/60000 [==============================] - 2s - loss: 149.2798 - val_loss: 152.0060
Epoch 34/50
60000/60000 [==============================] - 2s - loss: 149.1350 - val_loss: 152.0361
Epoch 35/50
60000/60000 [==============================] - 2s - loss: 148.9873 - val_loss: 151.8470
Epoch 36/50
60000/60000 [==============================] - 2s - loss: 148.8849 - val_loss: 151.7099
Epoch 37/50
60000/60000 [==============================] - 2s - loss: 148.7897 - val_loss: 151.5309
Epoch 38/50
60000/60000 [==============================] - 2s - loss: 148.6630 - val_loss: 151.6342
Epoch 39/50
60000/60000 [==============================] - 2s - loss: 148.5674 - val_loss: 151.7563
Epoch 40/50
60000/60000 [==============================] - 2s - loss: 148.4517 - val_loss: 151.5432
Epoch 41/50
60000/60000 [==============================] - 2s - loss: 148.3366 - val_loss: 151.4317
Epoch 42/50
60000/60000 [==============================] - 2s - loss: 148.2159 - val_loss: 151.4681
Epoch 43/50
60000/60000 [==============================] - 2s - loss: 148.1392 - val_loss: 151.6582
Epoch 44/50
60000/60000 [==============================] - 2s - loss: 148.0562 - val_loss: 151.6221
Epoch 45/50
60000/60000 [==============================] - 2s - loss: 147.9545 - val_loss: 151.3868
Epoch 46/50
60000/60000 [==============================] - 2s - loss: 147.8461 - val_loss: 151.4615
Epoch 47/50
60000/60000 [==============================] - 2s - loss: 147.7605 - val_loss: 151.7655
Epoch 48/50
60000/60000 [==============================] - 2s - loss: 147.6813 - val_loss: 151.1806
Epoch 49/50
60000/60000 [==============================] - 2s - loss: 147.6052 - val_loss: 151.4509
Epoch 50/50
60000/60000 [==============================] - 2s - loss: 147.5226 - val_loss: 151.4373
Out[24]:
<keras.callbacks.History at 0x7fb4912a1fd0>
In [38]:
X_test_encoded = encoder.predict([X_test, np.random.randn(len(X_train), latent_dim)])
In [91]:
fig, ax = plt.subplots(figsize=(6, 5))

# 2-D latent embedding of the test digits, colored by class label
scatter = ax.scatter(X_test_encoded[:, 0], X_test_encoded[:, 1],
                     c=y_test, alpha=.4, s=3**2, cmap='viridis')
fig.colorbar(scatter, ax=ax)

plt.show()
In [40]:
n = 15          # the manifold figure is an n x n grid of digits
digit_size = 28  # MNIST digits are 28x28
# blank canvas that the decoded digits are pasted into below
im = np.zeros((n * digit_size, n * digit_size))
In [53]:
# linearly spaced coordinates on the unit square were 
# transformed through the inverse CDF (ppf) of the Gaussian
# to produce values of the latent variables z, since the 
# prior of the latent space is Gaussian
# linearly spaced coordinates on the unit square are transformed
# through the inverse CDF (ppf) of the Gaussian to produce values of
# the latent variables z, since the prior of the latent space is
# Gaussian. Uses `norm` imported from scipy.stats at the top; the
# original called `sp.stats.norm.ppf` but `sp` was never imported.
u1, u2 = np.meshgrid(np.linspace(0.05, 0.95, n),
                     np.linspace(0.05, 0.95, n))
u_grid = np.dstack((u1, u2))
z_grid = norm.ppf(u_grid)
In [55]:
z_grid.shape
Out[55]:
(15, 15, 2)
In [63]:
# decode every latent grid point; flatten the grid to a batch of n*n
x_decoded = decoder.predict(z_grid.reshape(n*n, 2))
x_decoded.shape
Out[63]:
(225, 784)
In [92]:
# reshape flat decodings (n*n, 784) into an (n, n, 28, 28) grid of digits
x_decoded = x_decoded.reshape(n, n, digit_size, digit_size)
x_decoded.shape
Out[92]:
(15, 15, 28, 28)
In [93]:
# paste each decoded digit into its tile of the canvas
for row in range(n):
    for col in range(n):
        top, left = row * digit_size, col * digit_size
        im[top:top + digit_size,
           left:left + digit_size] = x_decoded[row, col]
In [94]:
fig, ax = plt.subplots(figsize=(7, 7))

# learned manifold: decoded digits over the latent grid
# (dropped a commented-out alternative imshow call left from development)
ax.imshow(im, cmap='gray')

plt.show()