variational_autoencoder-checkpoint.ipynb (Source)
Preamble
In [1]:
%matplotlib notebook
In [2]:
import numpy as np
from scipy.stats import norm
import matplotlib.pyplot as plt
from keras import backend as K
from keras.models import Model, Sequential
from keras.layers import (Dense, Input, Layer, Lambda,
Add, Multiply)
from keras.datasets import mnist
from keras.utils.vis_utils import model_to_dot, plot_model
from IPython.display import SVG
Notebook Configuration
In [3]:
np.set_printoptions(precision=2,
edgeitems=3,
linewidth=80,
suppress=True)
In [4]:
'TensorFlow version: ' + K.tf.__version__
Out[4]:
Constant definitions
In [5]:
mc_samples = 5
batch_size = 100
original_dim = 784
latent_dim = 2
intermediate_dim = 256
epochs = 50
epsilon_std = 1.0
Reparameterization
The reparameterization trick treats the noise as an auxiliary input to the network: instead of sampling z ~ N(mu, sigma^2) directly, we feed eps ~ N(0, 1) in as an input and compute z = mu + sigma * eps, so z is a deterministic, differentiable function of (x, eps) and gradients can flow through mu and sigma.
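A minimal NumPy sketch of this idea (illustrative only; the values are arbitrary and not part of the model):
# instead of sampling z ~ N(mu, sigma^2) directly, draw eps ~ N(0, 1)
# and compute z = mu + sigma * eps; the sample statistics recover mu and sigma
mu_demo, sigma_demo = 1.5, 0.3
eps_demo = np.random.randn(100000)
z_demo = mu_demo + sigma_demo * eps_demo
print(z_demo.mean(), z_demo.std())   # ≈ 1.5 and 0.3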
In [6]:
x = Input(shape=(original_dim,), name='x')
In [7]:
h = Dense(intermediate_dim, activation='relu', name='hidden')(x)
In [8]:
z_mu = Dense(latent_dim, name='mu')(h)
z_log_var = Dense(latent_dim, name='log_var')(h)
In [9]:
class KLDivergenceLayer(Layer):
    """ Identity transform layer that adds the KL divergence
    to the prior to the model loss. """
    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)
    def call(self, inputs):
        mu, log_var = inputs
        kl = - .5 * K.sum(1 + log_var
                          - K.square(mu)
                          - K.exp(log_var), axis=-1)
        self.add_loss(K.mean(kl), inputs=inputs)
        return inputs
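The expression in call is the closed-form KL divergence between the diagonal Gaussian q(z|x) = N(mu, diag(sigma^2)) and the standard normal prior, KL = -1/2 Σ (1 + log σ² − μ² − σ²). A quick NumPy/SciPy check of that formula against a Monte Carlo estimate (illustrative only; the values are hypothetical and not part of the model):
mu_c, log_var_c = 0.8, np.log(0.5**2)          # hypothetical q(z|x) = N(0.8, 0.5^2)
kl_closed = -0.5 * (1 + log_var_c - mu_c**2 - np.exp(log_var_c))
z_s = mu_c + np.exp(0.5 * log_var_c) * np.random.randn(200000)
kl_mc = np.mean(norm.logpdf(z_s, mu_c, np.exp(0.5 * log_var_c)) - norm.logpdf(z_s))
print(kl_closed, kl_mc)                        # the two estimates should roughly agree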
In [10]:
z_mu, z_log_var = KLDivergenceLayer(name='kl')([z_mu, z_log_var])
In [11]:
sigma = Lambda(lambda x: K.exp(.5*x), name='sigma')(z_log_var)
In [49]:
# z_mean = Input(shape=(latent_dim,), name='mu')
# z_std_dev = Input(shape=(latent_dim,), name='sigma')
# eps = Input(shape=(mc_samples, latent_dim), name='eps')
In [36]:
# z_eps = Multiply(name='z_eps')([z_std_dev, eps])
# z = Add(name='z')([z_mean, z_eps])
In [37]:
# m = Model(inputs=[eps, z_mean, z_std_dev], outputs=z)
In [38]:
# SVG(model_to_dot(m, show_shapes=False).create(prog='dot', format='svg'))
Out[38]:
In [41]:
# plot_model(
# model=m, show_shapes=False,
# to_file='../images/vae/reparameterization.svg'
# )
In [42]:
# plot_model(
# model=m, show_shapes=True,
# to_file='../images/vae/reparameterization_shapes.svg'
# )
In [13]:
# eps could also carry multiple MC samples per datapoint, e.g.
# Input(shape=(mc_samples, latent_dim)); here we use a single sample
eps = Input(shape=(latent_dim,), name='epsilon')
sigma_eps = Multiply(name='sigma_eps')([sigma, eps])
z = Add(name='z')([z_mu, sigma_eps])
In [14]:
encoder = Model(inputs=[x, eps], outputs=z)
In [15]:
SVG(model_to_dot(encoder, show_shapes=True)
.create(prog='dot', format='svg'))
Out[15]:
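As a quick, illustrative shape check (a hypothetical random batch, not data from MNIST): the encoder maps flattened images plus matching noise to points in the 2-D latent space.
x_demo = np.random.rand(8, original_dim)
eps_demo = np.random.randn(8, latent_dim)
print(encoder.predict([x_demo, eps_demo]).shape)   # expected: (8, 2)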
Decoder
In [16]:
# decoder = Sequential([
# Dense(intermediate_dim, activation='relu', input_dim=latent_dim),
# Dense(original_dim, activation='sigmoid')
# ], name='decoder')
In [17]:
# x_decoded_mean = decoder(z)
In [18]:
# SVG(model_to_dot(decoder, show_shapes=True)
# .create(prog='dot', format='svg'))
In [26]:
decoder_h = Dense(intermediate_dim, activation='relu')
decoder_mean = Dense(original_dim, activation='sigmoid')
In [27]:
h_decoded = decoder_h(z)
x_decoded_mean = decoder_mean(h_decoded)
# standalone decoder model (reusing the same layers) used by decoder.predict below
decoder_input = Input(shape=(latent_dim,), name='z_sample')
decoder = Model(decoder_input, decoder_mean(decoder_h(decoder_input)), name='decoder')
In [39]:
def nll(y_true, y_pred):
    """ Negative log likelihood (Bernoulli). """
    # keras.losses.binary_crossentropy gives the mean
    # over the last axis; we require the sum
    return K.sum(K.binary_crossentropy(y_true, y_pred), axis=-1)
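A small sanity check with hypothetical toy values (illustrative only): the summed cross-entropy equals the per-pixel mean returned by the Keras backend multiplied by the number of pixels.
y_true_demo = K.constant([[0., 1., 1., 0.]])
y_pred_demo = K.constant([[0.1, 0.9, 0.8, 0.2]])
print(K.eval(K.sum(K.binary_crossentropy(y_true_demo, y_pred_demo), axis=-1)))
print(4 * K.eval(K.mean(K.binary_crossentropy(y_true_demo, y_pred_demo), axis=-1)))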
In [40]:
vae = Model(inputs=[x, eps], outputs=x_decoded_mean)
# the KL term was already attached via KLDivergenceLayer.add_loss,
# so the compiled loss (nll + KL) is the negative ELBO
vae.compile(optimizer='rmsprop', loss=nll)
In [41]:
SVG(model_to_dot(vae, show_shapes=True)
.create(prog='dot', format='svg'))
Out[41]:
In [42]:
(X_train, y_train), (X_test, y_test) = mnist.load_data()
X_train = X_train.reshape(-1, 28*28) / 255.
X_test = X_test.reshape(-1, 28*28) / 255.
In [43]:
vae.evaluate(
[X_train, np.random.randn(len(X_train), latent_dim)],
X_train,
batch_size=batch_size
)
In [24]:
vae.fit([X_train, np.random.randn(len(X_train), latent_dim)],
X_train,
shuffle=True,
epochs=epochs,
batch_size=batch_size,
validation_data=(
[X_test, np.random.randn(len(X_test), latent_dim)],
X_test))
Out[24]:
In [38]:
X_test_encoded = encoder.predict([X_test, np.random.randn(len(X_test), latent_dim)])
In [91]:
fig, ax = plt.subplots(figsize=(6, 5))
cbar = ax.scatter(X_test_encoded[:, 0], X_test_encoded[:, 1],
c=y_test, alpha=.4, s=3**2,
cmap='viridis')
fig.colorbar(cbar, ax=ax)
plt.show()
In [40]:
n = 15 # figure with 15x15 digits
digit_size = 28
im = np.zeros((digit_size * n, digit_size * n))
In [53]:
# linearly spaced coordinates on the unit square were
# transformed through the inverse CDF (ppf) of the Gaussian
# to produce values of the latent variables z, since the
# prior of the latent space is Gaussian
u1, u2 = np.meshgrid(np.linspace(0.05, 0.95, n),
np.linspace(0.05, 0.95, n))
u_grid = np.dstack((u1, u2))
z_grid = norm.ppf(u_grid)
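A quick illustration of the inverse-CDF mapping (arbitrary quantiles): equally spaced quantiles are pushed towards the tails where the standard normal places less mass, so the grid covers the high-probability region of the prior rather than a fixed square.
print(norm.ppf([0.05, 0.25, 0.5, 0.75, 0.95]))   # ≈ [-1.64, -0.67, 0., 0.67, 1.64]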
In [55]:
z_grid.shape
Out[55]:
In [63]:
x_decoded = decoder.predict(z_grid.reshape(n*n, 2))
x_decoded.shape
Out[63]:
In [92]:
x_decoded = x_decoded.reshape(n, n, digit_size, digit_size)
x_decoded.shape
Out[92]:
In [93]:
for i in range(n):
    for j in range(n):
        im[i * digit_size: (i + 1) * digit_size,
           j * digit_size: (j + 1) * digit_size] = x_decoded[i, j]
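The same tiling can also be done without explicit loops; a vectorized sketch, assuming x_decoded has shape (n, n, digit_size, digit_size) as above:
# swap the grid-column and pixel-row axes so each row of digits becomes
# contiguous, then collapse to a single 2-D image; matches the loop above
im_vec = x_decoded.transpose(0, 2, 1, 3).reshape(digit_size * n, digit_size * n)
print(np.allclose(im, im_vec))   # expected: True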
In [94]:
fig, ax = plt.subplots(figsize=(7, 7))
# ax.imshow(np.reshape(x_decoded, (28*15, 28*15), order='A'),
# cmap='gray')
ax.imshow(im, cmap='gray')
plt.show()