using_negative_log_likelihoods_of_tensorflow_distributions_as_keras_losses.ipynb (Source)

Preamble

In [1]:
# render figures with matplotlib's interactive notebook backend
%matplotlib notebook
In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from scipy.stats import norm

from keras import backend as K
from keras.layers import (Input, Activation, Dense, Lambda, Layer, 
                          add, multiply)
from keras.models import Model, Sequential
from keras.callbacks import TerminateOnNaN
from keras.datasets import mnist
Using TensorFlow backend.

Notebook Environment

In [4]:
# Compact array display: 2 decimals, no scientific notation, and only
# 3 entries shown at each edge of long dimensions.
np.set_printoptions(edgeitems=3, linewidth=80,
                    precision=2, suppress=True)
In [5]:
'TensorFlow version: {}'.format(K.tf.__version__)
Out[5]:
'TensorFlow version: 1.4.0'

TensorFlow Distributions

In [6]:
# standard-normal random tensor (mean/stddev spelled out explicitly)
random_tensor = K.random_normal((5, 784), mean=0., stddev=1., seed=42)
In [7]:
# Evaluates the sigmoid of the random tensor; compare with the next cell,
# which evaluates the very same expression again.
K.eval(K.sigmoid(random_tensor))
Out[7]:
array([[ 0.43,  0.47,  0.34, ...,  0.06,  0.81,  0.23],
       [ 0.21,  0.53,  0.18, ...,  0.45,  0.45,  0.58],
       [ 0.11,  0.16,  0.38, ...,  0.47,  0.35,  0.33],
       [ 0.21,  0.62,  0.43, ...,  0.55,  0.54,  0.67],
       [ 0.29,  0.31,  0.25, ...,  0.35,  0.71,  0.44]], dtype=float32)
In [8]:
# Same expression as the previous cell, yet the values differ: each
# evaluation re-runs the (seeded) random op and draws a fresh sample.
K.eval(K.sigmoid(random_tensor))
Out[8]:
array([[ 0.63,  0.68,  0.14, ...,  0.28,  0.5 ,  0.28],
       [ 0.78,  0.3 ,  0.31, ...,  0.37,  0.88,  0.31],
       [ 0.57,  0.53,  0.6 , ...,  0.34,  0.4 ,  0.21],
       [ 0.61,  0.38,  0.73, ...,  0.17,  0.41,  0.27],
       [ 0.63,  0.16,  0.38, ...,  0.5 ,  0.35,  0.44]], dtype=float32)

Bernoulli log probabilities

In [9]:
# Fresh random tensor; `pred` and `true` are both the sigmoid of it.
random_tensor = K.random_normal((5, 784), mean=0., stddev=1., seed=42)
pred, true = K.sigmoid(random_tensor), K.sigmoid(random_tensor)
In [10]:
# Bernoulli distribution parameterized directly by probabilities (not logits).
likelihood = K.tf.distributions.Bernoulli(probs=pred)
In [11]:
# total log-likelihood per row: per-element Bernoulli log-probs summed
K.eval(K.sum(likelihood.log_prob(true), axis=-1))
Out[11]:
array([-468.32, -472.9 , -470.78, -471.34, -470.37], dtype=float32)

Keras binary cross-entropy loss

In [12]:
# Rebuild the same inputs as in the Bernoulli section for the Keras loss.
random_tensor = K.random_normal((5, 784), mean=0., stddev=1., seed=42)
pred, true = K.sigmoid(random_tensor), K.sigmoid(random_tensor)
In [13]:
# Keras' elementwise binary cross-entropy, summed per row
K.eval(K.sum(K.binary_crossentropy(target=true, output=pred), axis=-1))
Out[13]:
array([ 468.32,  472.9 ,  470.78,  471.34,  470.37], dtype=float32)
In [14]:
# Evaluate both per-row losses in a single session run so they are computed
# from the same random draw (every separate K.eval re-samples the random op).
# This replaces comparing the saved outputs Out[11] / Out[13], which depend
# on execution counts and break when the notebook is re-run.
neg_log_lik, bce = K.batch_get_value([
    -K.sum(K.tf.distributions.Bernoulli(probs=pred).log_prob(value=true),
           axis=-1),
    K.sum(K.binary_crossentropy(true, pred), axis=-1),
])
np.allclose(neg_log_lik, bce)
Out[14]:
True

Example: Variational Autoencoder

Constant definitions
In [15]:
# --- input image geometry ---
img_rows, img_cols = 28, 28
original_dim = img_rows * img_cols   # length of a flattened image

# --- architecture ---
latent_dim = 2           # 2-D latent space, plotted directly later on
intermediate_dim = 128   # hidden units in encoder and decoder
epsilon_std = 1.0        # std-dev of the reparameterization noise

# --- optimization ---
batch_size = 100
epochs = 50

Dataset (MNIST)

In [16]:
# Load MNIST, flatten each image into an `original_dim`-long vector and
# scale pixel intensities by 1/255.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape((-1, original_dim)) / 255.
x_test = x_test.reshape((-1, original_dim)) / 255.

Model specification

In [17]:
class KLDivergenceLayer(Layer):

    """ Pass-through layer that adds the KL divergence between the
    diagonal Gaussian N(mu, exp(log_var)) and a standard normal to the
    model's loss.

    Takes the pair [mu, log_var] and returns it unchanged.
    """

    def __init__(self, *args, **kwargs):
        # flag is set before the base constructor runs, as originally
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):

        mean, log_variance = inputs

        # closed form: KL = -1/2 * sum_j (1 + log s_j^2 - m_j^2 - s_j^2)
        kl_per_sample = - .5 * K.sum(1 + log_variance -
                                     K.square(mean) -
                                     K.exp(log_variance), axis=-1)

        # batch-averaged KL is registered as an additional loss term
        self.add_loss(K.mean(kl_per_sample), inputs=inputs)

        return inputs
In [18]:
def make_vae(output_activation='sigmoid'):
    """ Build the variational autoencoder as a Keras model.

    The model takes the flattened images plus a noise tensor that is
    sampled inside the graph, and outputs the decoder's reconstruction.
    `output_activation` is the decoder's final activation; pass None to
    get unbounded outputs (used with the logits-based loss).
    """

    # --- inputs: images and reparameterization noise ---
    x_in = Input(shape=(original_dim,))
    noise = Input(tensor=K.random_normal(
                      stddev=epsilon_std,
                      shape=(K.shape(x_in)[0], latent_dim)))

    # --- encoder: mean and log-variance of q(z|x) ---
    hidden = Dense(intermediate_dim, activation='relu')(x_in)
    mean = Dense(latent_dim, name='z_mean')(hidden)
    log_var = Dense(latent_dim)(hidden)
    mean, log_var = KLDivergenceLayer()([mean, log_var])
    std = Lambda(lambda t: K.exp(.5 * t))(log_var)

    # --- reparameterization: z = mu + sigma * eps ---
    z = add([mean, multiply([std, noise])])

    # --- decoder ---
    decoder = Sequential([
        Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
        Dense(original_dim, activation=output_activation),
    ], name='decoder')

    return Model(inputs=[x_in, noise], outputs=decoder(z), name='vae')

Model fitting

Keras binary cross-entropy loss

In [19]:
def nll1(y_true, y_pred):
    """ Negative log likelihood, via Keras' binary cross-entropy.

    `keras.losses.binary_crossentropy` gives the mean over the last
    axis; the negative log likelihood needs the sum, so the elementwise
    backend function is summed here instead.
    """

    per_element = K.binary_crossentropy(y_true, y_pred)

    return K.sum(per_element, axis=-1)
In [20]:
vae1 = make_vae()  # decoder ends in a sigmoid
vae1.compile(loss=nll1, optimizer='rmsprop')
# autoencoder: the images are both the inputs and the targets
hist1 = vae1.fit(x=x_train, y=x_train,
                 batch_size=batch_size, epochs=epochs, shuffle=True,
                 validation_data=(x_test, x_test),
                 callbacks=[TerminateOnNaN()])
Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 3s 42us/step - loss: 197.8046 - val_loss: 174.5758
Epoch 2/50
60000/60000 [==============================] - 2s 35us/step - loss: 172.2795 - val_loss: 169.8304
Epoch 3/50
60000/60000 [==============================] - 2s 39us/step - loss: 168.8145 - val_loss: 167.5593
Epoch 4/50
60000/60000 [==============================] - 2s 38us/step - loss: 166.7302 - val_loss: 165.8718
Epoch 5/50
60000/60000 [==============================] - 2s 36us/step - loss: 165.1217 - val_loss: 164.6282
Epoch 6/50
60000/60000 [==============================] - 2s 38us/step - loss: 163.8013 - val_loss: 163.4731
Epoch 7/50
60000/60000 [==============================] - 2s 36us/step - loss: 162.6845 - val_loss: 162.1845
Epoch 8/50
60000/60000 [==============================] - 2s 35us/step - loss: 161.7113 - val_loss: 161.2801
Epoch 9/50
60000/60000 [==============================] - 2s 37us/step - loss: 160.8378 - val_loss: 160.4747
Epoch 10/50
60000/60000 [==============================] - 2s 40us/step - loss: 160.0908 - val_loss: 159.8515
Epoch 11/50
60000/60000 [==============================] - 2s 38us/step - loss: 159.4220 - val_loss: 159.6488
Epoch 12/50
60000/60000 [==============================] - 2s 40us/step - loss: 158.7991 - val_loss: 158.7215
Epoch 13/50
60000/60000 [==============================] - 2s 36us/step - loss: 158.2604 - val_loss: 158.6874
Epoch 14/50
60000/60000 [==============================] - 2s 33us/step - loss: 157.8040 - val_loss: 157.8198
Epoch 15/50
60000/60000 [==============================] - 2s 35us/step - loss: 157.3699 - val_loss: 157.6631
Epoch 16/50
60000/60000 [==============================] - 2s 36us/step - loss: 157.0013 - val_loss: 157.4282
Epoch 17/50
60000/60000 [==============================] - 2s 37us/step - loss: 156.6732 - val_loss: 156.8004
Epoch 18/50
60000/60000 [==============================] - 2s 32us/step - loss: 156.3473 - val_loss: 156.9447
Epoch 19/50
60000/60000 [==============================] - 2s 34us/step - loss: 156.0587 - val_loss: 156.5810
Epoch 20/50
60000/60000 [==============================] - 2s 35us/step - loss: 155.7752 - val_loss: 156.4226
Epoch 21/50
60000/60000 [==============================] - 2s 34us/step - loss: 155.5554 - val_loss: 156.0716
Epoch 22/50
60000/60000 [==============================] - 2s 32us/step - loss: 155.2919 - val_loss: 155.9490
Epoch 23/50
60000/60000 [==============================] - 2s 33us/step - loss: 155.0792 - val_loss: 155.9058
Epoch 24/50
60000/60000 [==============================] - 2s 36us/step - loss: 154.8542 - val_loss: 155.8581
Epoch 25/50
60000/60000 [==============================] - 2s 32us/step - loss: 154.6692 - val_loss: 155.5282
Epoch 26/50
60000/60000 [==============================] - 2s 35us/step - loss: 154.4754 - val_loss: 155.5255
Epoch 27/50
60000/60000 [==============================] - 2s 35us/step - loss: 154.3106 - val_loss: 155.6988
Epoch 28/50
60000/60000 [==============================] - 2s 39us/step - loss: 154.1398 - val_loss: 155.4417
Epoch 29/50
60000/60000 [==============================] - 2s 34us/step - loss: 153.9401 - val_loss: 154.8930
Epoch 30/50
60000/60000 [==============================] - 2s 35us/step - loss: 153.7739 - val_loss: 155.1084
Epoch 31/50
60000/60000 [==============================] - 2s 35us/step - loss: 153.6338 - val_loss: 155.0252
Epoch 32/50
60000/60000 [==============================] - 3s 43us/step - loss: 153.5130 - val_loss: 154.5574
Epoch 33/50
60000/60000 [==============================] - 2s 35us/step - loss: 153.3315 - val_loss: 154.6764
Epoch 34/50
60000/60000 [==============================] - 2s 33us/step - loss: 153.2158 - val_loss: 154.3616
Epoch 35/50
60000/60000 [==============================] - 2s 32us/step - loss: 153.1154 - val_loss: 154.3324
Epoch 36/50
60000/60000 [==============================] - 2s 35us/step - loss: 152.9549 - val_loss: 154.6139
Epoch 37/50
60000/60000 [==============================] - 3s 52us/step - loss: 152.8326 - val_loss: 154.0197
Epoch 38/50
60000/60000 [==============================] - 3s 46us/step - loss: 152.7438 - val_loss: 154.0712
Epoch 39/50
60000/60000 [==============================] - 3s 45us/step - loss: 152.6157 - val_loss: 153.8970
Epoch 40/50
60000/60000 [==============================] - 2s 36us/step - loss: 152.5009 - val_loss: 153.9140
Epoch 41/50
60000/60000 [==============================] - 2s 32us/step - loss: 152.4002 - val_loss: 154.2383
Epoch 42/50
60000/60000 [==============================] - 2s 33us/step - loss: 152.2922 - val_loss: 153.5777
Epoch 43/50
60000/60000 [==============================] - 2s 33us/step - loss: 152.1851 - val_loss: 153.9621
Epoch 44/50
60000/60000 [==============================] - 3s 45us/step - loss: 152.0959 - val_loss: 153.8199
Epoch 45/50
60000/60000 [==============================] - 2s 36us/step - loss: 151.9867 - val_loss: 153.5573
Epoch 46/50
60000/60000 [==============================] - 2s 32us/step - loss: 151.9200 - val_loss: 153.9015
Epoch 47/50
60000/60000 [==============================] - 2s 32us/step - loss: 151.8279 - val_loss: 153.5111
Epoch 48/50
60000/60000 [==============================] - 2s 32us/step - loss: 151.7621 - val_loss: 153.4209
Epoch 49/50
60000/60000 [==============================] - 2s 32us/step - loss: 151.6646 - val_loss: 153.2584
Epoch 50/50
60000/60000 [==============================] - 2s 33us/step - loss: 151.5816 - val_loss: 153.3303

Numerical instability of Bernoulli with probabilities (probs)

In [21]:
def nll2(y_true, y_pred):
    """ Negative log likelihood of `y_true` under a Bernoulli
    distribution parameterized by the probabilities `y_pred`.

    Training with this loss hit a NaN loss in this notebook; the section
    heading attributes this to numerical instability of the probs
    parameterization.
    """

    log_probs = K.tf.distributions.Bernoulli(probs=y_pred).log_prob(y_true)

    return -K.sum(log_probs, axis=-1)
In [22]:
vae2 = make_vae()  # decoder ends in a sigmoid, loss uses Bernoulli(probs)
vae2.compile(loss=nll2, optimizer='rmsprop')
# TerminateOnNaN stops training as soon as the loss becomes invalid
hist2 = vae2.fit(x=x_train, y=x_train,
                 batch_size=batch_size, epochs=epochs, shuffle=True,
                 validation_data=(x_test, x_test),
                 callbacks=[TerminateOnNaN()])
Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 2s 41us/step - loss: 195.6972 - val_loss: 175.0890
Epoch 2/50
35700/60000 [================>.............] - ETA: 0s - loss: 174.6592Batch 368: Invalid loss, terminating training
36900/60000 [=================>............] - ETA: 0s - loss: nan     

Bernoulli with sigmoid log-odds (logits)

In [23]:
def nll3(y_true, y_pred):
    """ Negative log likelihood of `y_true` under a Bernoulli
    distribution whose log-odds (logits) are `y_pred`.

    Intended for a decoder with no output activation.
    """

    log_probs = K.tf.distributions.Bernoulli(logits=y_pred).log_prob(y_true)

    return -K.sum(log_probs, axis=-1)
In [24]:
vae3 = make_vae(output_activation=None)  # decoder outputs logits
vae3.compile(loss=nll3, optimizer='rmsprop')
hist3 = vae3.fit(x=x_train, y=x_train,
                 batch_size=batch_size, epochs=epochs, shuffle=True,
                 validation_data=(x_test, x_test),
                 callbacks=[TerminateOnNaN()])
Train on 60000 samples, validate on 10000 samples
Epoch 1/50
60000/60000 [==============================] - 2s 39us/step - loss: 197.3443 - val_loss: 174.8810
Epoch 2/50
60000/60000 [==============================] - 2s 38us/step - loss: 171.7976 - val_loss: 170.0754
Epoch 3/50
60000/60000 [==============================] - 2s 41us/step - loss: 168.7622 - val_loss: 168.0805
Epoch 4/50
60000/60000 [==============================] - 2s 36us/step - loss: 166.9539 - val_loss: 166.3837
Epoch 5/50
60000/60000 [==============================] - 3s 46us/step - loss: 165.3324 - val_loss: 165.0904
Epoch 6/50
60000/60000 [==============================] - 2s 37us/step - loss: 163.9750 - val_loss: 163.6643
Epoch 7/50
60000/60000 [==============================] - 3s 45us/step - loss: 162.9434 - val_loss: 163.1811
Epoch 8/50
60000/60000 [==============================] - 2s 35us/step - loss: 162.1413 - val_loss: 162.0859
Epoch 9/50
60000/60000 [==============================] - 2s 39us/step - loss: 161.5138 - val_loss: 161.5494
Epoch 10/50
60000/60000 [==============================] - 3s 46us/step - loss: 160.9495 - val_loss: 161.1756
Epoch 11/50
60000/60000 [==============================] - 2s 39us/step - loss: 160.4728 - val_loss: 160.6088
Epoch 12/50
60000/60000 [==============================] - 2s 40us/step - loss: 160.0173 - val_loss: 160.2232
Epoch 13/50
60000/60000 [==============================] - 2s 38us/step - loss: 159.5693 - val_loss: 159.9956
Epoch 14/50
60000/60000 [==============================] - 2s 38us/step - loss: 159.0975 - val_loss: 159.5913
Epoch 15/50
60000/60000 [==============================] - 2s 41us/step - loss: 158.7159 - val_loss: 159.1639
Epoch 16/50
60000/60000 [==============================] - 2s 40us/step - loss: 158.3084 - val_loss: 158.7390
Epoch 17/50
60000/60000 [==============================] - 2s 39us/step - loss: 157.9626 - val_loss: 158.3853
Epoch 18/50
60000/60000 [==============================] - 2s 38us/step - loss: 157.6130 - val_loss: 158.0789
Epoch 19/50
60000/60000 [==============================] - 2s 35us/step - loss: 157.2902 - val_loss: 157.9533
Epoch 20/50
60000/60000 [==============================] - 2s 37us/step - loss: 156.9979 - val_loss: 157.8502
Epoch 21/50
60000/60000 [==============================] - 2s 37us/step - loss: 156.6903 - val_loss: 157.6366
Epoch 22/50
60000/60000 [==============================] - 2s 39us/step - loss: 156.4382 - val_loss: 157.5489
Epoch 23/50
60000/60000 [==============================] - 2s 41us/step - loss: 156.1680 - val_loss: 156.9574
Epoch 24/50
60000/60000 [==============================] - 2s 41us/step - loss: 155.9509 - val_loss: 156.8963
Epoch 25/50
60000/60000 [==============================] - 2s 39us/step - loss: 155.7179 - val_loss: 156.6620
Epoch 26/50
60000/60000 [==============================] - 3s 42us/step - loss: 155.5165 - val_loss: 156.3996
Epoch 27/50
60000/60000 [==============================] - 3s 43us/step - loss: 155.2751 - val_loss: 156.3764
Epoch 28/50
60000/60000 [==============================] - 3s 45us/step - loss: 155.0779 - val_loss: 156.0982
Epoch 29/50
60000/60000 [==============================] - 3s 44us/step - loss: 154.9262 - val_loss: 156.1363
Epoch 30/50
60000/60000 [==============================] - 3s 43us/step - loss: 154.7112 - val_loss: 155.9733
Epoch 31/50
60000/60000 [==============================] - 3s 43us/step - loss: 154.5473 - val_loss: 155.5940
Epoch 32/50
60000/60000 [==============================] - 3s 42us/step - loss: 154.3638 - val_loss: 155.5866
Epoch 33/50
60000/60000 [==============================] - 3s 43us/step - loss: 154.1933 - val_loss: 155.5409
Epoch 34/50
60000/60000 [==============================] - 3s 47us/step - loss: 154.0451 - val_loss: 155.0881
Epoch 35/50
60000/60000 [==============================] - 3s 46us/step - loss: 153.8880 - val_loss: 155.1717
Epoch 36/50
60000/60000 [==============================] - 3s 46us/step - loss: 153.7204 - val_loss: 155.1469
Epoch 37/50
60000/60000 [==============================] - 2s 41us/step - loss: 153.5952 - val_loss: 155.1155
Epoch 38/50
60000/60000 [==============================] - 2s 40us/step - loss: 153.4529 - val_loss: 154.7922
Epoch 39/50
60000/60000 [==============================] - 2s 39us/step - loss: 153.3160 - val_loss: 154.8435
Epoch 40/50
60000/60000 [==============================] - 3s 42us/step - loss: 153.1796 - val_loss: 154.7323
Epoch 41/50
60000/60000 [==============================] - 3s 44us/step - loss: 153.0591 - val_loss: 154.4354
Epoch 42/50
60000/60000 [==============================] - 2s 39us/step - loss: 152.9476 - val_loss: 154.5590
Epoch 43/50
60000/60000 [==============================] - 2s 39us/step - loss: 152.8587 - val_loss: 154.3976
Epoch 44/50
60000/60000 [==============================] - 2s 40us/step - loss: 152.7007 - val_loss: 154.4268
Epoch 45/50
60000/60000 [==============================] - 2s 39us/step - loss: 152.6327 - val_loss: 154.2704
Epoch 46/50
60000/60000 [==============================] - 2s 39us/step - loss: 152.5104 - val_loss: 153.9890
Epoch 47/50
60000/60000 [==============================] - 3s 46us/step - loss: 152.3909 - val_loss: 154.0543
Epoch 48/50
60000/60000 [==============================] - 3s 46us/step - loss: 152.2902 - val_loss: 153.9199
Epoch 49/50
60000/60000 [==============================] - 3s 44us/step - loss: 152.2137 - val_loss: 154.0787
Epoch 50/50
60000/60000 [==============================] - 3s 43us/step - loss: 152.1108 - val_loss: 153.9863

Model Evaluation

In [25]:
def get_encoder(vae):
    """ Model mapping images to the output of the VAE's 'z_mean' layer. """

    image_input, _ = vae.inputs  # second input is the noise tensor

    return Model(inputs=image_input,
                 outputs=vae.get_layer('z_mean').output)
In [26]:
def get_decoder(vae, activation='sigmoid'):
    """ Return the generative network of `vae` as a standalone model.

    The trained sub-model named 'decoder' is reused, so its weights are
    shared with `vae`.

    Parameters
    ----------
    vae : Keras model containing a layer named 'decoder'.
    activation : activation appended after the decoder. The default,
        'sigmoid', maps logits to probabilities and suits a VAE built with
        ``make_vae(output_activation=None)``. Pass None to return the
        decoder unchanged, e.g. for a VAE whose decoder already ends in a
        sigmoid, where a second sigmoid would distort its outputs.
    """

    decoder = vae.get_layer('decoder')

    if activation is None:
        return decoder

    return Sequential([decoder, Activation(activation)])
In [27]:
# vae3 (logits decoder) is evaluated; get_decoder appends the sigmoid
encoder, decoder = get_encoder(vae3), get_decoder(vae3)

NELBO

In [28]:
def golden_size(width):
    """ (width, height) pair whose aspect ratio is the golden ratio. """
    return width, 2. * width / (1 + np.sqrt(5))
In [29]:
# Training/validation loss per epoch of vae3; the loss is the NLL plus
# the KL term, i.e. the negative ELBO.
fig, ax = plt.subplots(figsize=golden_size(6))

hist_df = pd.DataFrame(hist3.history)
hist_df.plot(ax=ax)

ax.set(xlabel='# epochs', ylabel='NELBO')

plt.show()

Observed space manifold

In [30]:
# display a 2D manifold of the images
n = 15  # figure with 15x15 images
digit_size = 28
quantile_min, quantile_max = 0.01, 0.99

# The latent prior is Gaussian, so evenly spaced quantiles are mapped
# through its inverse CDF (ppf) to obtain latent coordinates. z2 runs from
# high to low so that the first grid row holds the largest z2 value.
z1 = norm.ppf(np.linspace(quantile_min, quantile_max, n))
z2 = norm.ppf(np.linspace(quantile_max, quantile_min, n))
z_grid = np.stack(np.meshgrid(z1, z2), axis=-1)  # shape (n, n, 2)
In [31]:
# decode every grid point, then arrange as
# (grid row, grid col, pixel row, pixel col)
x_pred_grid = decoder.predict(z_grid.reshape(-1, latent_dim))
x_pred_grid = x_pred_grid.reshape(n, n, img_rows, img_cols)
In [32]:
fig, ax = plt.subplots(figsize=(4, 4))

# Tile the (n, n, img_rows, img_cols) decodings into a single
# (n*img_rows, n*img_cols) image.
ax.imshow(np.block(list(map(list, x_pred_grid))), cmap='gray')

# One tick per tile, centred on it. The x axis runs over image columns
# (tile width img_cols), the y axis over image rows (tile height img_rows).
ax.set_xticks(np.arange(0, n*img_cols, img_cols) + .5 * img_cols)
ax.set_xticklabels(['{:.2f}'.format(v) for v in z1], rotation=90)

ax.set_yticks(np.arange(0, n*img_rows, img_rows) + .5 * img_rows)
ax.set_yticklabels(['{:.2f}'.format(v) for v in z2])

ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')

plt.show()
In [33]:
# latent coordinates (encoder outputs) of every test image, for the
# 2D plot of the digit classes in the latent space
z_test = encoder.predict(x_test, batch_size=batch_size, verbose=0)
In [34]:
# Encoded test images, coloured by digit label.
fig, ax = plt.subplots(figsize=(6, 5))

cbar = ax.scatter(z_test[:, 0], z_test[:, 1], c=y_test,
                  cmap='viridis', alpha=.4, s=9)
fig.colorbar(cbar, ax=ax)

# limits: twice the z-range spanned by the manifold grid
ax.set(xlim=2. * norm.ppf((quantile_min, quantile_max)),
       ylim=2. * norm.ppf((quantile_min, quantile_max)),
       xlabel='$z_1$',
       ylabel='$z_2$')

plt.show()
In [35]:
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(12, 4.5))

# Left: decoded manifold, tiled into one (n*img_rows, n*img_cols) image.
ax1.imshow(np.block(list(map(list, x_pred_grid))), cmap='gray')

# x runs over image columns (tile width img_cols), y over rows (img_rows).
ax1.set_xticks(np.arange(0, n*img_cols, img_cols) + .5 * img_cols)
ax1.set_xticklabels(['{:.2f}'.format(v) for v in z1], rotation=90)

ax1.set_yticks(np.arange(0, n*img_rows, img_rows) + .5 * img_rows)
ax1.set_yticklabels(['{:.2f}'.format(v) for v in z2])

# Label this figure's left panel, not the `ax` left over from another cell.
ax1.set_xlabel('$z_1$')
ax1.set_ylabel('$z_2$')

# Right: encoded test images, coloured by digit label.
cbar = ax2.scatter(z_test[:, 0], z_test[:, 1], c=y_test,
                   alpha=.4, s=3**2, cmap='viridis')
fig.colorbar(cbar, ax=ax2)

# Limits match the z-range of the manifold grid in the left panel.
ax2.set_xlim(norm.ppf((quantile_min, quantile_max)))
ax2.set_ylim(norm.ppf((quantile_min, quantile_max)))

ax2.set_xlabel('$z_1$')
ax2.set_ylabel('$z_2$')

plt.show()