Using Negative Log-Likelihoods of TensorFlow Distributions as Keras Losses
Preamble¶
In [1]:
%matplotlib notebook
In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from scipy.stats import norm
from keras import backend as K
from keras.layers import (Input, Activation, Dense, Lambda, Layer,
                          add, multiply)
from keras.models import Model, Sequential
from keras.callbacks import TerminateOnNaN
from keras.datasets import mnist
Notebook Environment¶
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
'TensorFlow version: ' + K.tf.__version__
Out[5]:
TensorFlow Distributions¶
In [6]:
random_tensor = K.random_normal(shape=(5, 784), seed=42)
In [7]:
K.eval(K.sigmoid(random_tensor))
Out[7]:
In [8]:
K.eval(K.sigmoid(random_tensor))
Out[8]:
Bernoulli log probabilities¶
In [9]:
random_tensor = K.random_normal(shape=(5, 784), seed=42)
pred = K.sigmoid(random_tensor)
true = K.sigmoid(random_tensor)
In [10]:
likelihood = K.tf.distributions.Bernoulli(probs=pred)
In [11]:
K.eval(K.sum(likelihood.log_prob(value=true), axis=-1))
Out[11]:
Keras binary cross-entropy loss¶
In [12]:
random_tensor = K.random_normal(shape=(5, 784), seed=42)
pred = K.sigmoid(random_tensor)
true = K.sigmoid(random_tensor)
In [13]:
K.eval(K.sum(K.binary_crossentropy(true, pred), axis=-1))
Out[13]:
In [14]:
np.allclose(-Out[11], Out[13])
Out[14]:
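The two results agree because binary cross-entropy is just the negative Bernoulli log-probability written out: for a target $t$ and predicted probability $\theta$,
$$-\log \mathrm{Bern}(t \mid \theta) = -\big[t \log \theta + (1 - t)\log(1 - \theta)\big] = \mathrm{BCE}(t, \theta),$$
so summing either expression over the 784 pixels yields the same per-image value, up to sign.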
Example: Variational Autoencoder¶
Constant definitions¶
In [15]:
# input image dimensions
img_rows = 28
img_cols = 28
original_dim = img_rows * img_cols
latent_dim = 2
intermediate_dim = 128
epsilon_std = 1.0
batch_size = 100
epochs = 50
Dataset (MNIST)¶
In [16]:
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, original_dim) / 255.
x_test = x_test.reshape(-1, original_dim) / 255.
Model specification¶
In [17]:
class KLDivergenceLayer(Layer):
    """ Identity transform layer that adds KL divergence
    to the final model loss.
    """

    def __init__(self, *args, **kwargs):
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):
        mu, log_var = inputs
        kl_batch = - .5 * K.sum(1 + log_var -
                                K.square(mu) -
                                K.exp(log_var), axis=-1)
        self.add_loss(K.mean(kl_batch), inputs=inputs)
        return inputs
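The expression inside K.sum is the closed-form KL divergence between the diagonal Gaussian $q(z \mid x) = \mathcal{N}(\mu, \operatorname{diag}(\sigma^2))$ and the standard normal prior $p(z) = \mathcal{N}(0, I)$. A minimal NumPy sketch, with made-up values of $\mu$ and $\log\sigma^2$, that checks it against a Monte Carlo estimate:

import numpy as np
from scipy.stats import norm

rng = np.random.RandomState(0)
mu, log_var = rng.randn(latent_dim), rng.randn(latent_dim)
sigma = np.exp(.5 * log_var)

# closed form, exactly as computed in KLDivergenceLayer.call
kl_closed_form = -.5 * np.sum(1 + log_var - mu**2 - np.exp(log_var))

# Monte Carlo estimate of E_q[log q(z) - log p(z)]
z = mu + sigma * rng.randn(100000, latent_dim)
kl_monte_carlo = np.mean(np.sum(norm.logpdf(z, mu, sigma) - norm.logpdf(z), axis=1))

# the two should agree up to Monte Carlo error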
In [18]:
def make_vae(output_activation='sigmoid'):

    x = Input(shape=(original_dim,))
    eps = Input(tensor=K.random_normal(
        stddev=epsilon_std,
        shape=(K.shape(x)[0], latent_dim)))

    h = Dense(intermediate_dim, activation='relu')(x)
    z_mu = Dense(latent_dim, name='z_mean')(h)
    z_log_var = Dense(latent_dim)(h)
    z_mu, z_log_var = KLDivergenceLayer()([z_mu, z_log_var])

    z_sigma = Lambda(lambda t: K.exp(.5*t))(z_log_var)
    z_eps = multiply([z_sigma, eps])
    z = add([z_mu, z_eps])

    decoder = Sequential([
        Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
        Dense(original_dim, activation=output_activation)
    ], name='decoder')

    x_pred = decoder(z)

    return Model(inputs=[x, eps], outputs=x_pred, name='vae')
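The Lambda, multiply and add layers implement the reparameterization trick: rather than sampling $z \sim \mathcal{N}(\mu, \sigma^2)$ directly, the model computes $z = \mu + \sigma \odot \varepsilon$ with $\varepsilon \sim \mathcal{N}(0, 1)$, so the sample remains differentiable with respect to $\mu$ and $\log\sigma^2$. A quick NumPy illustration with arbitrary values:

import numpy as np

rng = np.random.RandomState(0)
mu, log_var = 1.5, -1.0

eps = rng.randn(100000)
z = mu + np.exp(.5 * log_var) * eps

z.mean(), z.var()  # approximately mu and exp(log_var)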
Model fitting¶
Keras binary cross-entropy loss¶
In [19]:
def nll1(y_true, y_pred):
    """ Negative log likelihood. """

    # keras.losses.binary_crossentropy gives the mean
    # over the last axis; we require the sum
    return K.sum(K.binary_crossentropy(y_true, y_pred), axis=-1)
In [20]:
vae1 = make_vae()
vae1.compile(optimizer='rmsprop', loss=nll1)
hist1 = vae1.fit(x_train,
                 x_train,
                 shuffle=True,
                 epochs=epochs,
                 batch_size=batch_size,
                 validation_data=(x_test, x_test),
                 callbacks=[TerminateOnNaN()])
Numerical instability of Bernoulli with probabilities (probs)¶
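When the distribution is parameterized by probabilities, the log-likelihood evaluates log(p) and log(1 - p) directly. Once the decoder's sigmoid saturates, p rounds to exactly 0 or 1 in float32 and the loss picks up -inf, which is precisely what the TerminateOnNaN callback guards against. A small NumPy sketch of the failure mode:

import numpy as np

logits = np.array([20., -20.], dtype=np.float32)  # confidently "on"/"off" pixels
probs = 1. / (1. + np.exp(-logits))               # saturates to exactly [1., 0.]

np.log(probs), np.log(1. - probs)                 # -inf appears in both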
In [21]:
def nll2(y_true, y_pred):
    """ Negative log likelihood. """
    likelihood = K.tf.distributions.Bernoulli(probs=y_pred)
    return - K.sum(likelihood.log_prob(y_true), axis=-1)
In [22]:
vae2 = make_vae()
vae2.compile(optimizer='rmsprop', loss=nll2)
hist2 = vae2.fit(x_train,
                 x_train,
                 shuffle=True,
                 epochs=epochs,
                 batch_size=batch_size,
                 validation_data=(x_test, x_test),
                 callbacks=[TerminateOnNaN()])
Bernoulli with sigmoid log-odds (logits)¶
In [23]:
def nll3(y_true, y_pred):
    """ Negative log likelihood. """
    likelihood = K.tf.distributions.Bernoulli(logits=y_pred)
    return - K.sum(likelihood.log_prob(y_true), axis=-1)
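Parameterized by logits, the same negative log-likelihood never takes the log of a saturated probability: for a logit $x$ and target $t$ it reduces to $\max(x, 0) - xt + \log(1 + e^{-|x|})$, the standard numerically stable form of sigmoid cross-entropy, which is essentially what TensorFlow computes under the hood. A NumPy sketch with extreme logits:

import numpy as np

def bernoulli_nll_from_logits(x, t):
    """ Stable per-pixel negative log likelihood of Bernoulli(logits=x) at t. """
    return np.maximum(x, 0.) - x * t + np.log1p(np.exp(-np.abs(x)))

logits = np.array([50., -50.], dtype=np.float32)  # extreme logits
targets = np.array([0., 1.], dtype=np.float32)    # worst-case targets

bernoulli_nll_from_logits(logits, targets)        # finite: array([50., 50.])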
In [24]:
vae3 = make_vae(output_activation=None)
vae3.compile(optimizer='rmsprop', loss=nll3)
hist3 = vae3.fit(x_train,
                 x_train,
                 shuffle=True,
                 epochs=epochs,
                 batch_size=batch_size,
                 validation_data=(x_test, x_test),
                 callbacks=[TerminateOnNaN()])
Model Evaluation¶
In [25]:
def get_encoder(vae):
    x, _ = vae.inputs
    z_mu = vae.get_layer('z_mean').output
    return Model(x, z_mu)
In [26]:
def get_decoder(vae):
    decoder = Sequential([
        vae.get_layer('decoder'),
        Activation('sigmoid')
    ])
    return decoder
In [27]:
encoder = get_encoder(vae3)
decoder = get_decoder(vae3)
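With the encoder and decoder split out, reconstructions are a round trip through the latent mean. A minimal sketch on a handful of test images, reusing x_test and the image dimensions defined above:

x_sample = x_test[:5]
x_recon = decoder.predict(encoder.predict(x_sample))

fig, axes = plt.subplots(nrows=2, ncols=5, figsize=(8, 3.5))
for i in range(5):
    axes[0, i].imshow(x_sample[i].reshape(img_rows, img_cols), cmap='gray')
    axes[1, i].imshow(x_recon[i].reshape(img_rows, img_cols), cmap='gray')
    axes[0, i].axis('off')
    axes[1, i].axis('off')
plt.show()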
NELBO¶
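The quantity being minimized, and plotted below, is the negative evidence lower bound: the reconstruction negative log-likelihood returned by nll3 plus the KL term contributed by KLDivergenceLayer, estimated with a single sample of $\varepsilon$ per example,
$$\mathrm{NELBO}(x) = -\,\mathbb{E}_{q(z \mid x)}\big[\log p(x \mid z)\big] + \mathrm{KL}\big[q(z \mid x) \,\|\, p(z)\big].$$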
In [28]:
golden_size = lambda width: (width, 2. * width / (1 + np.sqrt(5)))
In [29]:
fig, ax = plt.subplots(figsize=golden_size(6))
hist_df = pd.DataFrame(hist3.history)
hist_df.plot(ax=ax)
ax.set_ylabel('NELBO')
ax.set_xlabel('# epochs')
plt.show()
Observed space manifold¶
In [30]:
# display a 2D manifold of the images
n = 15 # figure with 15x15 images
digit_size = 28
quantile_min = 0.01
quantile_max = 0.99
# linearly spaced coordinates on the unit square are transformed
# through the inverse CDF (ppf) of the Gaussian to produce values
# of the latent variables z, since the prior of the latent space
# is Gaussian
z1 = norm.ppf(np.linspace(quantile_min, quantile_max, n))
z2 = norm.ppf(np.linspace(quantile_max, quantile_min, n))
z_grid = np.dstack(np.meshgrid(z1, z2))
In [31]:
x_pred_grid = decoder.predict(z_grid.reshape(n*n, latent_dim)) \
                     .reshape(n, n, img_rows, img_cols)
In [32]:
fig, ax = plt.subplots(figsize=(4, 4))
ax.imshow(np.block(list(map(list, x_pred_grid))), cmap='gray')
ax.set_xticks(np.arange(0, n*img_rows, img_rows) + .5 * img_rows)
ax.set_xticklabels(map('{:.2f}'.format, z1), rotation=90)
ax.set_yticks(np.arange(0, n*img_cols, img_cols) + .5 * img_cols)
ax.set_yticklabels(map('{:.2f}'.format, z2))
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')
plt.show()
In [33]:
# display a 2D plot of the digit classes in the latent space
z_test = encoder.predict(x_test, batch_size=batch_size)
In [34]:
fig, ax = plt.subplots(figsize=(6, 5))
cbar = ax.scatter(z_test[:, 0], z_test[:, 1], c=y_test,
                  alpha=.4, s=3**2, cmap='viridis')
fig.colorbar(cbar, ax=ax)
ax.set_xlim(2.*norm.ppf((quantile_min, quantile_max)))
ax.set_ylim(2.*norm.ppf((quantile_min, quantile_max)))
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')
plt.show()
In [35]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 4.5))
ax1.imshow(np.block(list(map(list, x_pred_grid))), cmap='gray')
ax1.set_xticks(np.arange(0, n*img_rows, img_rows) + .5 * img_rows)
ax1.set_xticklabels(map('{:.2f}'.format, z1), rotation=90)
ax1.set_yticks(np.arange(0, n*img_cols, img_cols) + .5 * img_cols)
ax1.set_yticklabels(map('{:.2f}'.format, z2))
ax1.set_xlabel('$z_1$')
ax1.set_ylabel('$z_2$')
cbar = ax2.scatter(z_test[:, 0], z_test[:, 1], c=y_test,
                   alpha=.4, s=3**2, cmap='viridis')
fig.colorbar(cbar, ax=ax2)
ax2.set_xlim(norm.ppf((quantile_min, quantile_max)))
ax2.set_ylim(norm.ppf((quantile_min, quantile_max)))
ax2.set_xlabel('$z_1$')
ax2.set_ylabel('$z_2$')
plt.show()