variational_autoencoder_mc_samples_grid-checkpoint.ipynb (Source)

Preamble

In [1]:
%matplotlib notebook
In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras import backend as K

from keras.layers import (Input, InputLayer, Dense, Lambda, Layer, 
                          Add, Multiply)
from keras.models import Model, Sequential
from keras.datasets import mnist
Using TensorFlow backend.
In [20]:
import pandas as pd

from functools import partial
from matplotlib.ticker import FormatStrFormatter
from keras.utils.vis_utils import model_to_dot, plot_model
from IPython.display import SVG

Notebook Configuration

In [4]:
# Compact NumPy array printing for the rest of the notebook: two decimals,
# three edge items per dimension, 80-column lines, no scientific notation
# for small values.
np.set_printoptions(precision=2, edgeitems=3, linewidth=80, suppress=True)
In [5]:
# Record the TensorFlow version the Keras backend is running on.
'TensorFlow version: {}'.format(K.tf.__version__)
Out[5]:
'TensorFlow version: 1.4.0'
Constant definitions
In [6]:
# Monte Carlo sample sizes compared in the experiments below.
mc_sample_sizes = [1, 5, 25]

# Network dimensions.
original_dim = 28 * 28   # length of a flattened input image (784)
intermediate_dim = 256   # width of the hidden Dense layers
latent_dim = 2           # dimensionality of the latent variable z

# Training settings.
epochs = 50
epsilon_std = 1.0        # stddev of the noise drawn for sampling z
In [7]:
# Load MNIST and flatten every image to a vector of length original_dim.
# Dividing by 255. turns the arrays into floats; presumably this maps 8-bit
# pixel intensities into [0, 1] (confirm against the dataset dtype).
# The labels y_train / y_test are loaded but not used by the cells shown here.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
x_train = x_train.reshape(-1, original_dim) / 255.
x_test = x_test.reshape(-1, original_dim) / 255.
In [8]:
def nll(y_true, y_pred):
    """Bernoulli negative log likelihood, summed over the last axis.

    The backend's element-wise binary cross-entropy is computed first and
    then summed (not averaged) along ``axis=-1``, so the result has one
    fewer dimension than the inputs.
    """
    elementwise_bce = K.binary_crossentropy(y_true, y_pred)
    return K.sum(elementwise_bce, axis=-1)
In [9]:
class KLDivergenceLayer(Layer):
    """Pass-through layer that registers a KL divergence penalty.

    The layer receives ``[mu, log_var]`` and returns them unchanged. As a
    side effect of being called it adds the batch mean of

        -0.5 * sum(1 + log_var - mu^2 - exp(log_var), axis=-1)

    to the model's losses via ``add_loss``.
    """

    def __init__(self, *args, **kwargs):
        # Set before delegating to the base-class constructor, which gets
        # all positional and keyword arguments untouched.
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):
        mu, log_var = inputs

        # Term summed over the last axis for every example in the batch.
        summand = 1 + log_var - K.square(mu) - K.exp(log_var)
        kl_batch = -.5 * K.sum(summand, axis=-1)

        # Average over the batch and attach the scalar to the model loss.
        self.add_loss(K.mean(kl_batch), inputs=inputs)

        # Identity transform: hand the inputs straight back.
        return inputs
In [10]:
def build_vae(mc_sample_size, original_dim, latent_dim, intermediate_dim):
    """Assemble the VAE as a Keras ``Model``.

    Arguments:
        mc_sample_size: size of the middle axis of the noise tensor, i.e.
            the number of noise draws per input example.
        original_dim: length of each (flattened) input vector.
        latent_dim: size of the latent code.
        intermediate_dim: width of the hidden ``Dense`` layers in both the
            encoder and the decoder.

    Returns:
        A ``Model`` with inputs ``[x, eps]`` and the decoder output as its
        single output. ``eps`` is an ``Input`` wrapping a
        ``K.random_normal`` tensor whose stddev is read from the
        module-level ``epsilon_std``.
    """
    # Encoder: one hidden layer, then two linear heads.
    inputs = Input(shape=(original_dim,))
    hidden = Dense(intermediate_dim, activation='relu')(inputs)

    mean = Dense(latent_dim)(hidden)
    log_variance = Dense(latent_dim)(hidden)

    # Identity layer whose only job is to register the KL loss term.
    mean, log_variance = KLDivergenceLayer()([mean, log_variance])
    std = Lambda(lambda t: K.exp(.5*t))(log_variance)

    # Noise of shape (dynamic batch size, mc_sample_size, latent_dim).
    noise_shape = (K.shape(inputs)[0], mc_sample_size, latent_dim)
    noise = Input(tensor=K.random_normal(stddev=epsilon_std,
                                         shape=noise_shape))

    # Reparameterisation: z = mean + std * noise.
    scaled_noise = Multiply()([std, noise])
    latent = Add()([mean, scaled_noise])

    decoder = Sequential([
        Dense(intermediate_dim, input_dim=latent_dim, activation='relu'),
        Dense(original_dim, activation='sigmoid')
    ])

    reconstruction = decoder(latent)

    return Model(inputs=[inputs, noise], outputs=reconstruction)
In [16]:
def fit_history(batch_size, mc_sample_size):
    """Build a fresh VAE and train it on the module-level ``x_train``.

    Arguments:
        batch_size: mini-batch size passed to ``fit``.
        mc_sample_size: number of Monte Carlo noise samples per example;
            forwarded to ``build_vae`` and used to replicate the targets.

    Returns:
        Whatever ``vae.fit`` returns (``plot_fit_history`` reads its
        ``history`` attribute).

    Reads the module-level ``x_train``, ``original_dim``, ``latent_dim``,
    ``intermediate_dim`` and ``epochs``. Only the training set is used, so
    no replicated copy of the test set is built.
    """
    print('batch size {} | MC sample size {}'
          .format(batch_size, mc_sample_size))

    # Repeat every training example mc_sample_size times along a new
    # axis 1 so the target has the same middle axis as the noise tensor
    # created in build_vae.
    x_train_target = np.tile(np.expand_dims(x_train, axis=1),
                             reps=(1, mc_sample_size, 1))

    vae = build_vae(mc_sample_size, original_dim, latent_dim,
                    intermediate_dim)
    vae.compile(optimizer='rmsprop', loss=nll)

    return vae.fit(x_train,
                   x_train_target,
                   shuffle=True,
                   epochs=epochs,
                   batch_size=batch_size)
In [29]:
def plot_fit_history(h, batch_size, mc_sample_size, ax=None):
    """Draw the ``'loss'`` column of ``h.history`` as a labelled curve.

    Arguments:
        h: object whose ``history`` attribute can be turned into a
            ``DataFrame`` containing a ``'loss'`` column.
        batch_size: integer shown in the axes title.
        mc_sample_size: integer shown in the curve's legend label.
        ax: axes to draw on; the current axes are used when omitted.
    """
    ax = plt.gca() if ax is None else ax

    loss_frame = pd.DataFrame(h.history)
    curve_label = 'MC samples: {:d}'.format(mc_sample_size)
    loss_frame.plot(ax=ax, y='loss', label=curve_label)

    ax.set_title('Batch size: {:d}'.format(batch_size))
    ax.set_ylabel('NELBO')
    ax.set_xlabel('# epochs')
In [18]:
def golden_figsize(width):
    """Return a ``(width, height)`` tuple in golden-ratio proportions.

    The height is ``2 * width / (1 + sqrt(5))``, i.e. the width divided by
    the golden ratio, so the result can be passed directly as ``figsize``.
    """
    return (width, 2. * width / (1 + np.sqrt(5)))
In [25]:
# Train one VAE per MC sample size with a batch size of 100.
histories1 = [fit_history(100, mc_sample_size)
              for mc_sample_size in mc_sample_sizes]
batch size 100 | MC sample size 1
Epoch 1/50
60000/60000 [==============================] - 2s 40us/step - loss: 189.0387
Epoch 2/50
60000/60000 [==============================] - 2s 32us/step - loss: 169.0337
Epoch 3/50
60000/60000 [==============================] - 2s 31us/step - loss: 165.6830
Epoch 4/50
60000/60000 [==============================] - 2s 31us/step - loss: 163.4543
Epoch 5/50
60000/60000 [==============================] - 2s 31us/step - loss: 161.6128
Epoch 6/50
60000/60000 [==============================] - 2s 32us/step - loss: 160.1299
Epoch 7/50
60000/60000 [==============================] - 2s 32us/step - loss: 158.9801
Epoch 8/50
60000/60000 [==============================] - 2s 32us/step - loss: 158.1048
Epoch 9/50
60000/60000 [==============================] - 2s 32us/step - loss: 157.3637
Epoch 10/50
60000/60000 [==============================] - 2s 32us/step - loss: 156.7864
Epoch 11/50
60000/60000 [==============================] - 2s 32us/step - loss: 156.2620
Epoch 12/50
60000/60000 [==============================] - 2s 31us/step - loss: 155.7970
Epoch 13/50
60000/60000 [==============================] - 2s 32us/step - loss: 155.4058
Epoch 14/50
60000/60000 [==============================] - 2s 30us/step - loss: 155.0614
Epoch 15/50
60000/60000 [==============================] - 2s 31us/step - loss: 154.7134
Epoch 16/50
60000/60000 [==============================] - 2s 31us/step - loss: 154.4399
Epoch 17/50
60000/60000 [==============================] - 2s 31us/step - loss: 154.1532
Epoch 18/50
60000/60000 [==============================] - 2s 31us/step - loss: 153.9101
Epoch 19/50
60000/60000 [==============================] - 2s 31us/step - loss: 153.6731
Epoch 20/50
60000/60000 [==============================] - 2s 31us/step - loss: 153.4678
Epoch 21/50
60000/60000 [==============================] - 2s 31us/step - loss: 153.2793
Epoch 22/50
60000/60000 [==============================] - 2s 31us/step - loss: 153.0618
Epoch 23/50
60000/60000 [==============================] - 2s 31us/step - loss: 152.8981
Epoch 24/50
60000/60000 [==============================] - 2s 31us/step - loss: 152.7158
Epoch 25/50
60000/60000 [==============================] - 2s 31us/step - loss: 152.5932
Epoch 26/50
60000/60000 [==============================] - 2s 31us/step - loss: 152.4110
Epoch 27/50
60000/60000 [==============================] - 2s 31us/step - loss: 152.2653
Epoch 28/50
60000/60000 [==============================] - 2s 31us/step - loss: 152.1090
Epoch 29/50
60000/60000 [==============================] - 2s 31us/step - loss: 151.9918
Epoch 30/50
60000/60000 [==============================] - 2s 31us/step - loss: 151.8823
Epoch 31/50
60000/60000 [==============================] - 2s 31us/step - loss: 151.7416
Epoch 32/50
60000/60000 [==============================] - 2s 31us/step - loss: 151.6229
Epoch 33/50
60000/60000 [==============================] - 2s 31us/step - loss: 151.5112
Epoch 34/50
60000/60000 [==============================] - 2s 31us/step - loss: 151.4135
Epoch 35/50
60000/60000 [==============================] - 2s 30us/step - loss: 151.3030
Epoch 36/50
60000/60000 [==============================] - 2s 30us/step - loss: 151.1725
Epoch 37/50
60000/60000 [==============================] - 2s 30us/step - loss: 151.0907
Epoch 38/50
60000/60000 [==============================] - 2s 30us/step - loss: 150.9281
Epoch 39/50
60000/60000 [==============================] - 2s 31us/step - loss: 150.8562
Epoch 40/50
60000/60000 [==============================] - 2s 31us/step - loss: 150.7535
Epoch 41/50
60000/60000 [==============================] - 2s 31us/step - loss: 150.6494
Epoch 42/50
60000/60000 [==============================] - 2s 31us/step - loss: 150.5700
Epoch 43/50
60000/60000 [==============================] - 2s 31us/step - loss: 150.4739
Epoch 44/50
60000/60000 [==============================] - 2s 32us/step - loss: 150.3871
Epoch 45/50
60000/60000 [==============================] - 2s 31us/step - loss: 150.2817
Epoch 46/50
60000/60000 [==============================] - 2s 31us/step - loss: 150.1927
Epoch 47/50
60000/60000 [==============================] - 2s 30us/step - loss: 150.1212
Epoch 48/50
60000/60000 [==============================] - 2s 30us/step - loss: 149.9925
Epoch 49/50
60000/60000 [==============================] - 2s 30us/step - loss: 149.9569
Epoch 50/50
60000/60000 [==============================] - 2s 30us/step - loss: 149.8732
batch size 100 | MC sample size 5
Epoch 1/50
60000/60000 [==============================] - 4s 62us/step - loss: 190.1918
Epoch 2/50
60000/60000 [==============================] - 3s 53us/step - loss: 169.5446
Epoch 3/50
60000/60000 [==============================] - 3s 54us/step - loss: 166.2662
Epoch 4/50
60000/60000 [==============================] - 3s 53us/step - loss: 164.0870
Epoch 5/50
60000/60000 [==============================] - 3s 54us/step - loss: 162.5137
Epoch 6/50
60000/60000 [==============================] - 3s 53us/step - loss: 161.3143
Epoch 7/50
60000/60000 [==============================] - 3s 53us/step - loss: 160.2275
Epoch 8/50
60000/60000 [==============================] - 3s 53us/step - loss: 159.1967
Epoch 9/50
60000/60000 [==============================] - 3s 53us/step - loss: 158.1983
Epoch 10/50
60000/60000 [==============================] - 3s 53us/step - loss: 157.3319
Epoch 11/50
60000/60000 [==============================] - 3s 54us/step - loss: 156.5584
Epoch 12/50
60000/60000 [==============================] - 3s 54us/step - loss: 155.9154
Epoch 13/50
60000/60000 [==============================] - 3s 54us/step - loss: 155.3372
Epoch 14/50
60000/60000 [==============================] - 3s 54us/step - loss: 154.8555
Epoch 15/50
60000/60000 [==============================] - 3s 54us/step - loss: 154.4196
Epoch 16/50
60000/60000 [==============================] - 3s 53us/step - loss: 154.0410
Epoch 17/50
60000/60000 [==============================] - 3s 55us/step - loss: 153.6830
Epoch 18/50
60000/60000 [==============================] - 3s 55us/step - loss: 153.3692
Epoch 19/50
60000/60000 [==============================] - 3s 55us/step - loss: 153.0748
Epoch 20/50
60000/60000 [==============================] - 3s 55us/step - loss: 152.8088
Epoch 21/50
60000/60000 [==============================] - 3s 54us/step - loss: 152.5538
Epoch 22/50
60000/60000 [==============================] - 3s 53us/step - loss: 152.2983
Epoch 23/50
60000/60000 [==============================] - 3s 53us/step - loss: 152.0915
Epoch 24/50
60000/60000 [==============================] - 3s 53us/step - loss: 151.8647
Epoch 25/50
60000/60000 [==============================] - 3s 54us/step - loss: 151.6694
Epoch 26/50
60000/60000 [==============================] - 3s 53us/step - loss: 151.4926
Epoch 27/50
60000/60000 [==============================] - 3s 53us/step - loss: 151.3275
Epoch 28/50
60000/60000 [==============================] - 3s 53us/step - loss: 151.1350
Epoch 29/50
60000/60000 [==============================] - 3s 53us/step - loss: 151.0158
Epoch 30/50
60000/60000 [==============================] - 3s 53us/step - loss: 150.8242
Epoch 31/50
60000/60000 [==============================] - 3s 53us/step - loss: 150.6936
Epoch 32/50
60000/60000 [==============================] - 3s 54us/step - loss: 150.5376
Epoch 33/50
60000/60000 [==============================] - 3s 54us/step - loss: 150.4095
Epoch 34/50
60000/60000 [==============================] - 3s 53us/step - loss: 150.2751
Epoch 35/50
60000/60000 [==============================] - 3s 53us/step - loss: 150.1579
Epoch 36/50
60000/60000 [==============================] - 3s 53us/step - loss: 150.0314
Epoch 37/50
60000/60000 [==============================] - 3s 54us/step - loss: 149.9030
Epoch 38/50
60000/60000 [==============================] - 3s 53us/step - loss: 149.7720
Epoch 39/50
60000/60000 [==============================] - 3s 54us/step - loss: 149.6809
Epoch 40/50
60000/60000 [==============================] - 3s 53us/step - loss: 149.5872
Epoch 41/50
60000/60000 [==============================] - 3s 53us/step - loss: 149.4470
Epoch 42/50
60000/60000 [==============================] - 3s 53us/step - loss: 149.3811
Epoch 43/50
60000/60000 [==============================] - 3s 54us/step - loss: 149.2661
Epoch 44/50
60000/60000 [==============================] - 3s 53us/step - loss: 149.1528
Epoch 45/50
60000/60000 [==============================] - 3s 53us/step - loss: 149.0910
Epoch 46/50
60000/60000 [==============================] - 3s 53us/step - loss: 148.9832
Epoch 47/50
60000/60000 [==============================] - 3s 54us/step - loss: 148.8849
Epoch 48/50
60000/60000 [==============================] - 3s 55us/step - loss: 148.7902
Epoch 49/50
60000/60000 [==============================] - 3s 53us/step - loss: 148.6930
Epoch 50/50
60000/60000 [==============================] - 3s 53us/step - loss: 148.6199
batch size 100 | MC sample size 25
Epoch 1/50
60000/60000 [==============================] - 11s 176us/step - loss: 190.2488
Epoch 2/50
60000/60000 [==============================] - 10s 170us/step - loss: 169.3187
Epoch 3/50
60000/60000 [==============================] - 10s 169us/step - loss: 166.2625
Epoch 4/50
60000/60000 [==============================] - 10s 169us/step - loss: 164.1589
Epoch 5/50
60000/60000 [==============================] - 10s 170us/step - loss: 162.6340
Epoch 6/50
60000/60000 [==============================] - 10s 167us/step - loss: 161.3680
Epoch 7/50
60000/60000 [==============================] - 10s 169us/step - loss: 160.2168
Epoch 8/50
60000/60000 [==============================] - 10s 169us/step - loss: 159.1401
Epoch 9/50
60000/60000 [==============================] - 10s 173us/step - loss: 158.1189
Epoch 10/50
60000/60000 [==============================] - 10s 174us/step - loss: 157.2397
Epoch 11/50
60000/60000 [==============================] - 10s 168us/step - loss: 156.5077
Epoch 12/50
60000/60000 [==============================] - 10s 163us/step - loss: 155.8691
Epoch 13/50
60000/60000 [==============================] - 11s 181us/step - loss: 155.3454
Epoch 14/50
60000/60000 [==============================] - 10s 171us/step - loss: 154.8616
Epoch 15/50
60000/60000 [==============================] - 10s 165us/step - loss: 154.4592
Epoch 16/50
60000/60000 [==============================] - 10s 172us/step - loss: 154.0723
Epoch 17/50
60000/60000 [==============================] - 10s 174us/step - loss: 153.7450
Epoch 18/50
60000/60000 [==============================] - 10s 169us/step - loss: 153.4252
Epoch 19/50
60000/60000 [==============================] - 11s 175us/step - loss: 153.1287
Epoch 20/50
60000/60000 [==============================] - 10s 170us/step - loss: 152.8623
Epoch 21/50
60000/60000 [==============================] - 11s 181us/step - loss: 152.6080
Epoch 22/50
60000/60000 [==============================] - 11s 177us/step - loss: 152.3954
Epoch 23/50
60000/60000 [==============================] - 11s 179us/step - loss: 152.1744
Epoch 24/50
60000/60000 [==============================] - 10s 171us/step - loss: 151.9752
Epoch 25/50
60000/60000 [==============================] - 10s 172us/step - loss: 151.7443
Epoch 26/50
60000/60000 [==============================] - 11s 176us/step - loss: 151.5782
Epoch 27/50
60000/60000 [==============================] - 10s 169us/step - loss: 151.3990
Epoch 28/50
60000/60000 [==============================] - 10s 168us/step - loss: 151.2458
Epoch 29/50
60000/60000 [==============================] - 10s 168us/step - loss: 151.0896
Epoch 30/50
60000/60000 [==============================] - 10s 167us/step - loss: 150.9448
Epoch 31/50
60000/60000 [==============================] - 10s 172us/step - loss: 150.7903
Epoch 32/50
60000/60000 [==============================] - 10s 170us/step - loss: 150.6565
Epoch 33/50
60000/60000 [==============================] - 11s 177us/step - loss: 150.5447
Epoch 34/50
60000/60000 [==============================] - 10s 166us/step - loss: 150.3798
Epoch 35/50
60000/60000 [==============================] - 10s 168us/step - loss: 150.2577
Epoch 36/50
60000/60000 [==============================] - 10s 174us/step - loss: 150.1473
Epoch 37/50
60000/60000 [==============================] - 10s 172us/step - loss: 150.0306
Epoch 38/50
60000/60000 [==============================] - 11s 180us/step - loss: 149.9003
Epoch 39/50
60000/60000 [==============================] - 10s 174us/step - loss: 149.7781
Epoch 40/50
60000/60000 [==============================] - 10s 169us/step - loss: 149.6790
Epoch 41/50
60000/60000 [==============================] - 10s 168us/step - loss: 149.5762
Epoch 42/50
60000/60000 [==============================] - 10s 167us/step - loss: 149.4664
Epoch 43/50
60000/60000 [==============================] - 10s 174us/step - loss: 149.3887
Epoch 44/50
60000/60000 [==============================] - 10s 171us/step - loss: 149.2608
Epoch 45/50
60000/60000 [==============================] - 10s 168us/step - loss: 149.1942
Epoch 46/50
60000/60000 [==============================] - 10s 170us/step - loss: 149.0778
Epoch 47/50
60000/60000 [==============================] - 10s 167us/step - loss: 148.9875
Epoch 48/50
60000/60000 [==============================] - 11s 176us/step - loss: 148.8957
Epoch 49/50
60000/60000 [==============================] - 10s 170us/step - loss: 148.8262
Epoch 50/50
60000/60000 [==============================] - 10s 170us/step - loss: 148.7420
In [30]:
# Overlay the training-loss curves of the batch-size-100 runs on one axes,
# one curve per MC sample size.
fig, ax = plt.subplots(figsize=golden_figsize(6))

for mc_sample_size, history in zip(mc_sample_sizes, histories1):
    plot_fit_history(history, batch_size=100,
                     mc_sample_size=mc_sample_size, ax=ax)

plt.show()
In [31]:
# Train one VAE per MC sample size with a batch size of 65.
histories2 = [fit_history(65, mc_sample_size)
              for mc_sample_size in mc_sample_sizes]
batch size 65 | MC sample size 1
Epoch 1/50
60000/60000 [==============================] - 3s 55us/step - loss: 184.4563
Epoch 2/50
60000/60000 [==============================] - 3s 51us/step - loss: 168.0012
Epoch 3/50
60000/60000 [==============================] - 3s 44us/step - loss: 165.0855
Epoch 4/50
60000/60000 [==============================] - 3s 42us/step - loss: 163.2620
Epoch 5/50
60000/60000 [==============================] - 3s 45us/step - loss: 161.7270
Epoch 6/50
60000/60000 [==============================] - 3s 43us/step - loss: 160.3433
Epoch 7/50
60000/60000 [==============================] - 3s 47us/step - loss: 159.1166
Epoch 8/50
60000/60000 [==============================] - 3s 57us/step - loss: 158.0708
Epoch 9/50
60000/60000 [==============================] - 3s 42us/step - loss: 157.2468
Epoch 10/50
60000/60000 [==============================] - 2s 42us/step - loss: 156.5685
Epoch 11/50
60000/60000 [==============================] - 3s 42us/step - loss: 155.9980
Epoch 12/50
60000/60000 [==============================] - 3s 42us/step - loss: 155.4974
Epoch 13/50
60000/60000 [==============================] - 3s 42us/step - loss: 155.0259
Epoch 14/50
60000/60000 [==============================] - 3s 50us/step - loss: 154.6323
Epoch 15/50
60000/60000 [==============================] - 3s 48us/step - loss: 154.2812
Epoch 16/50
60000/60000 [==============================] - 3s 46us/step - loss: 153.9239
Epoch 17/50
60000/60000 [==============================] - 3s 46us/step - loss: 153.6179
Epoch 18/50
60000/60000 [==============================] - 3s 49us/step - loss: 153.3159
Epoch 19/50
60000/60000 [==============================] - 3s 47us/step - loss: 153.0734
Epoch 20/50
60000/60000 [==============================] - 3s 43us/step - loss: 152.8139
Epoch 21/50
60000/60000 [==============================] - 3s 42us/step - loss: 152.6091
Epoch 22/50
60000/60000 [==============================] - 2s 42us/step - loss: 152.4014
Epoch 23/50
60000/60000 [==============================] - 2s 41us/step - loss: 152.2169
Epoch 24/50
60000/60000 [==============================] - 2s 40us/step - loss: 152.0270
Epoch 25/50
60000/60000 [==============================] - 3s 43us/step - loss: 151.8327
Epoch 26/50
60000/60000 [==============================] - 3s 43us/step - loss: 151.7105
Epoch 27/50
60000/60000 [==============================] - 3s 50us/step - loss: 151.5365
Epoch 28/50
60000/60000 [==============================] - 3s 42us/step - loss: 151.3909
Epoch 29/50
60000/60000 [==============================] - 3s 47us/step - loss: 151.2554
Epoch 30/50
60000/60000 [==============================] - 3s 44us/step - loss: 151.1200
Epoch 31/50
60000/60000 [==============================] - 3s 43us/step - loss: 150.9801
Epoch 32/50
60000/60000 [==============================] - 3s 50us/step - loss: 150.8682
Epoch 33/50
60000/60000 [==============================] - 3s 44us/step - loss: 150.7814
Epoch 34/50
60000/60000 [==============================] - 3s 46us/step - loss: 150.6491
Epoch 35/50
60000/60000 [==============================] - 3s 52us/step - loss: 150.5359
Epoch 36/50
60000/60000 [==============================] - 3s 57us/step - loss: 150.4560
Epoch 37/50
60000/60000 [==============================] - 3s 52us/step - loss: 150.3336
Epoch 38/50
60000/60000 [==============================] - 3s 49us/step - loss: 150.2309
Epoch 39/50
60000/60000 [==============================] - 3s 47us/step - loss: 150.1449
Epoch 40/50
60000/60000 [==============================] - 3s 46us/step - loss: 150.0229
Epoch 41/50
60000/60000 [==============================] - 3s 45us/step - loss: 149.9789
Epoch 42/50
60000/60000 [==============================] - 3s 45us/step - loss: 149.8242
Epoch 43/50
60000/60000 [==============================] - 3s 49us/step - loss: 149.7797
Epoch 44/50
60000/60000 [==============================] - 3s 52us/step - loss: 149.6702
Epoch 45/50
60000/60000 [==============================] - 3s 47us/step - loss: 149.5554
Epoch 46/50
60000/60000 [==============================] - 3s 48us/step - loss: 149.4697
Epoch 47/50
60000/60000 [==============================] - 3s 56us/step - loss: 149.3671
Epoch 48/50
60000/60000 [==============================] - 3s 46us/step - loss: 149.2490
Epoch 49/50
60000/60000 [==============================] - 2s 39us/step - loss: 149.2043
Epoch 50/50
60000/60000 [==============================] - 2s 39us/step - loss: 149.1012
batch size 65 | MC sample size 5
Epoch 1/50
60000/60000 [==============================] - 4s 68us/step - loss: 184.9790
Epoch 2/50
60000/60000 [==============================] - 4s 59us/step - loss: 168.8429
Epoch 3/50
60000/60000 [==============================] - 4s 59us/step - loss: 164.4541
Epoch 4/50
60000/60000 [==============================] - 3s 58us/step - loss: 161.9596
Epoch 5/50
60000/60000 [==============================] - 4s 63us/step - loss: 160.1261
Epoch 6/50
60000/60000 [==============================] - 4s 62us/step - loss: 158.7714
Epoch 7/50
60000/60000 [==============================] - 4s 61us/step - loss: 157.7245
Epoch 8/50
60000/60000 [==============================] - 3s 57us/step - loss: 156.9012
Epoch 9/50
60000/60000 [==============================] - 3s 57us/step - loss: 156.2454
Epoch 10/50
60000/60000 [==============================] - 3s 57us/step - loss: 155.6829
Epoch 11/50
60000/60000 [==============================] - 3s 57us/step - loss: 155.1900
Epoch 12/50
60000/60000 [==============================] - 3s 58us/step - loss: 154.7334
Epoch 13/50
60000/60000 [==============================] - 3s 57us/step - loss: 154.3399
Epoch 14/50
60000/60000 [==============================] - 3s 57us/step - loss: 153.9612
Epoch 15/50
60000/60000 [==============================] - 3s 57us/step - loss: 153.6554
Epoch 16/50
60000/60000 [==============================] - 3s 57us/step - loss: 153.3446
Epoch 17/50
60000/60000 [==============================] - 4s 62us/step - loss: 153.0813
Epoch 18/50
60000/60000 [==============================] - 4s 64us/step - loss: 152.7987
Epoch 19/50
60000/60000 [==============================] - 4s 68us/step - loss: 152.5490
Epoch 20/50
60000/60000 [==============================] - 4s 66us/step - loss: 152.3222
Epoch 21/50
60000/60000 [==============================] - 4s 64us/step - loss: 152.0873
Epoch 22/50
60000/60000 [==============================] - 4s 64us/step - loss: 151.8929
Epoch 23/50
60000/60000 [==============================] - 4s 64us/step - loss: 151.6653
Epoch 24/50
60000/60000 [==============================] - 4s 65us/step - loss: 151.4926
Epoch 25/50
60000/60000 [==============================] - 4s 64us/step - loss: 151.3432
Epoch 26/50
60000/60000 [==============================] - 4s 65us/step - loss: 151.1544
Epoch 27/50
60000/60000 [==============================] - 4s 66us/step - loss: 150.9863
Epoch 28/50
60000/60000 [==============================] - 5s 78us/step - loss: 150.8234
Epoch 29/50
60000/60000 [==============================] - 5s 76us/step - loss: 150.7162
Epoch 30/50
60000/60000 [==============================] - 5s 78us/step - loss: 150.5368
Epoch 31/50
60000/60000 [==============================] - 5s 81us/step - loss: 150.4092
Epoch 32/50
60000/60000 [==============================] - 5s 77us/step - loss: 150.2617
Epoch 33/50
60000/60000 [==============================] - 4s 74us/step - loss: 150.1359
Epoch 34/50
60000/60000 [==============================] - 4s 75us/step - loss: 150.0313
Epoch 35/50
60000/60000 [==============================] - 4s 73us/step - loss: 149.9199
Epoch 36/50
60000/60000 [==============================] - 4s 73us/step - loss: 149.8035
Epoch 37/50
60000/60000 [==============================] - 5s 78us/step - loss: 149.6750
Epoch 38/50
60000/60000 [==============================] - 5s 81us/step - loss: 149.5902
Epoch 39/50
60000/60000 [==============================] - 4s 73us/step - loss: 149.4580
Epoch 40/50
60000/60000 [==============================] - 4s 68us/step - loss: 149.3713
Epoch 41/50
60000/60000 [==============================] - 4s 69us/step - loss: 149.2645
Epoch 42/50
60000/60000 [==============================] - 4s 64us/step - loss: 149.1763
Epoch 43/50
60000/60000 [==============================] - 4s 63us/step - loss: 149.0803
Epoch 44/50
60000/60000 [==============================] - 4s 63us/step - loss: 148.9728
Epoch 45/50
60000/60000 [==============================] - 4s 64us/step - loss: 148.8995
Epoch 46/50
60000/60000 [==============================] - 4s 63us/step - loss: 148.8191
Epoch 47/50
60000/60000 [==============================] - 4s 63us/step - loss: 148.7228
Epoch 48/50
60000/60000 [==============================] - 4s 65us/step - loss: 148.6509
Epoch 49/50
60000/60000 [==============================] - 4s 64us/step - loss: 148.5672
Epoch 50/50
60000/60000 [==============================] - 4s 64us/step - loss: 148.5031
batch size 65 | MC sample size 25
Epoch 1/50
60000/60000 [==============================] - 12s 196us/step - loss: 185.0923
Epoch 2/50
60000/60000 [==============================] - 11s 186us/step - loss: 167.0106
Epoch 3/50
60000/60000 [==============================] - 11s 187us/step - loss: 163.4489
Epoch 4/50
60000/60000 [==============================] - 11s 186us/step - loss: 161.7727
Epoch 5/50
60000/60000 [==============================] - 11s 186us/step - loss: 160.6129
Epoch 6/50
60000/60000 [==============================] - 11s 190us/step - loss: 159.6776
Epoch 7/50
60000/60000 [==============================] - 11s 179us/step - loss: 158.9042
Epoch 8/50
60000/60000 [==============================] - 11s 183us/step - loss: 158.2042
Epoch 9/50
60000/60000 [==============================] - 11s 190us/step - loss: 157.5665
Epoch 10/50
60000/60000 [==============================] - 11s 190us/step - loss: 156.9723
Epoch 11/50
60000/60000 [==============================] - 11s 184us/step - loss: 156.5029
Epoch 12/50
60000/60000 [==============================] - 11s 180us/step - loss: 156.0532
Epoch 13/50
60000/60000 [==============================] - 11s 190us/step - loss: 155.6011
Epoch 14/50
60000/60000 [==============================] - 11s 185us/step - loss: 155.2406
Epoch 15/50
60000/60000 [==============================] - 11s 179us/step - loss: 154.9010
Epoch 16/50
60000/60000 [==============================] - 11s 182us/step - loss: 154.5755
Epoch 17/50
60000/60000 [==============================] - 11s 180us/step - loss: 154.3036
Epoch 18/50
60000/60000 [==============================] - 11s 187us/step - loss: 154.0090
Epoch 19/50
60000/60000 [==============================] - 11s 183us/step - loss: 153.7666
Epoch 20/50
60000/60000 [==============================] - 11s 179us/step - loss: 153.5260
Epoch 21/50
60000/60000 [==============================] - 11s 178us/step - loss: 153.3229
Epoch 22/50
60000/60000 [==============================] - 11s 182us/step - loss: 153.0734
Epoch 23/50
60000/60000 [==============================] - 11s 181us/step - loss: 152.8777
Epoch 24/50
60000/60000 [==============================] - 11s 184us/step - loss: 152.7168
Epoch 25/50
60000/60000 [==============================] - 11s 181us/step - loss: 152.5363
Epoch 26/50
60000/60000 [==============================] - 11s 181us/step - loss: 152.3606
Epoch 27/50
60000/60000 [==============================] - 11s 178us/step - loss: 152.1795
Epoch 28/50
60000/60000 [==============================] - 11s 181us/step - loss: 152.0166
Epoch 29/50
60000/60000 [==============================] - 11s 183us/step - loss: 151.8719
Epoch 30/50
60000/60000 [==============================] - 11s 183us/step - loss: 151.7429
Epoch 31/50
60000/60000 [==============================] - 11s 180us/step - loss: 151.6034
Epoch 32/50
60000/60000 [==============================] - 11s 180us/step - loss: 151.4665
Epoch 33/50
60000/60000 [==============================] - 11s 180us/step - loss: 151.3000
Epoch 34/50
60000/60000 [==============================] - 11s 181us/step - loss: 151.1969
Epoch 35/50
60000/60000 [==============================] - 11s 180us/step - loss: 151.0867
Epoch 36/50
60000/60000 [==============================] - 11s 180us/step - loss: 150.9416
Epoch 37/50
60000/60000 [==============================] - 11s 181us/step - loss: 150.8220
Epoch 38/50
60000/60000 [==============================] - 11s 185us/step - loss: 150.7098
Epoch 39/50
60000/60000 [==============================] - 11s 183us/step - loss: 150.5890
Epoch 40/50
60000/60000 [==============================] - 11s 180us/step - loss: 150.4597
Epoch 41/50
60000/60000 [==============================] - 11s 186us/step - loss: 150.3663
Epoch 42/50
60000/60000 [==============================] - 11s 182us/step - loss: 150.2396
Epoch 43/50
60000/60000 [==============================] - 11s 182us/step - loss: 150.1371
Epoch 44/50
60000/60000 [==============================] - 11s 189us/step - loss: 150.0284
Epoch 45/50
60000/60000 [==============================] - 12s 194us/step - loss: 149.9225
Epoch 46/50
60000/60000 [==============================] - 12s 195us/step - loss: 149.8169
Epoch 47/50
60000/60000 [==============================] - 12s 201us/step - loss: 149.7349
Epoch 48/50
60000/60000 [==============================] - 12s 202us/step - loss: 149.6022
Epoch 49/50
60000/60000 [==============================] - 12s 201us/step - loss: 149.5442
Epoch 50/50
60000/60000 [==============================] - 12s 200us/step - loss: 149.4351
In [35]:
# Overlay the training-loss curves of the batch-size-65 runs on one axes,
# one curve per MC sample size.
fig, ax = plt.subplots(figsize=golden_figsize(6))

for mc_sample_size, history in zip(mc_sample_sizes, histories2):
    plot_fit_history(history, batch_size=65,
                     mc_sample_size=mc_sample_size, ax=ax)

plt.show()
In [36]:
# Train one VAE per MC sample size with a batch size of 30.
histories3 = [fit_history(30, mc_sample_size)
              for mc_sample_size in mc_sample_sizes]
batch size 30 | MC sample size 1
Epoch 1/50
60000/60000 [==============================] - 7s 124us/step - loss: 177.2299
Epoch 2/50
60000/60000 [==============================] - 6s 105us/step - loss: 165.9526
Epoch 3/50
60000/60000 [==============================] - 6s 101us/step - loss: 163.3389
Epoch 4/50
60000/60000 [==============================] - 6s 102us/step - loss: 161.5406
Epoch 5/50
60000/60000 [==============================] - 6s 103us/step - loss: 159.9758
Epoch 6/50
60000/60000 [==============================] - 7s 109us/step - loss: 158.7068
Epoch 7/50
60000/60000 [==============================] - 6s 100us/step - loss: 157.7793
Epoch 8/50
60000/60000 [==============================] - 6s 104us/step - loss: 156.9706
Epoch 9/50
60000/60000 [==============================] - 6s 99us/step - loss: 156.3414
Epoch 10/50
60000/60000 [==============================] - 5s 89us/step - loss: 155.7998
Epoch 11/50
60000/60000 [==============================] - 6s 93us/step - loss: 155.3102
Epoch 12/50
60000/60000 [==============================] - 5s 81us/step - loss: 154.8747
Epoch 13/50
60000/60000 [==============================] - 5s 79us/step - loss: 154.5256
Epoch 14/50
60000/60000 [==============================] - 5s 83us/step - loss: 154.1773
Epoch 15/50
60000/60000 [==============================] - 5s 89us/step - loss: 153.9464
Epoch 16/50
60000/60000 [==============================] - 5s 83us/step - loss: 153.6999
Epoch 17/50
60000/60000 [==============================] - 5s 85us/step - loss: 153.4632
Epoch 18/50
60000/60000 [==============================] - 5s 89us/step - loss: 153.2248
Epoch 19/50
60000/60000 [==============================] - 5s 85us/step - loss: 153.1149
Epoch 20/50
60000/60000 [==============================] - 5s 86us/step - loss: 152.9554
Epoch 21/50
60000/60000 [==============================] - 5s 86us/step - loss: 152.8291
Epoch 22/50
60000/60000 [==============================] - 5s 86us/step - loss: 152.6762
Epoch 23/50
60000/60000 [==============================] - 5s 89us/step - loss: 152.5805
Epoch 24/50
60000/60000 [==============================] - 5s 90us/step - loss: 152.4390
Epoch 25/50
60000/60000 [==============================] - 5s 87us/step - loss: 152.3914
Epoch 26/50
60000/60000 [==============================] - 5s 82us/step - loss: 152.2789
Epoch 27/50
60000/60000 [==============================] - 5s 86us/step - loss: 152.2040
Epoch 28/50
60000/60000 [==============================] - 5s 82us/step - loss: 152.0841
Epoch 29/50
60000/60000 [==============================] - 5s 85us/step - loss: 152.0246
Epoch 30/50
60000/60000 [==============================] - 5s 84us/step - loss: 151.9551
Epoch 31/50
60000/60000 [==============================] - 5s 83us/step - loss: 151.9226
Epoch 32/50
60000/60000 [==============================] - 5s 84us/step - loss: 151.8533
Epoch 33/50
60000/60000 [==============================] - 5s 82us/step - loss: 151.8018
Epoch 34/50
60000/60000 [==============================] - 5s 82us/step - loss: 151.7092
Epoch 35/50
60000/60000 [==============================] - 5s 82us/step - loss: 151.6607
Epoch 36/50
60000/60000 [==============================] - 5s 82us/step - loss: 151.6110
Epoch 37/50
60000/60000 [==============================] - 5s 85us/step - loss: 151.5799
Epoch 38/50
60000/60000 [==============================] - 5s 86us/step - loss: 151.5609
Epoch 39/50
60000/60000 [==============================] - 6s 101us/step - loss: 151.5157
Epoch 40/50
60000/60000 [==============================] - 5s 90us/step - loss: 151.4409
Epoch 41/50
60000/60000 [==============================] - 5s 90us/step - loss: 151.4742
Epoch 42/50
60000/60000 [==============================] - 5s 89us/step - loss: 151.4231
Epoch 43/50
60000/60000 [==============================] - 5s 87us/step - loss: 151.3505
Epoch 44/50
60000/60000 [==============================] - 5s 91us/step - loss: 151.3848
Epoch 45/50
60000/60000 [==============================] - 5s 84us/step - loss: 151.3290
Epoch 46/50
60000/60000 [==============================] - 6s 92us/step - loss: 151.3031
Epoch 47/50
60000/60000 [==============================] - 5s 83us/step - loss: 151.2973
Epoch 48/50
60000/60000 [==============================] - 5s 80us/step - loss: 151.2716
Epoch 49/50
60000/60000 [==============================] - 5s 82us/step - loss: 151.1704
Epoch 50/50
60000/60000 [==============================] - 5s 85us/step - loss: 151.1809
batch size 30 | MC sample size 5
Epoch 1/50
60000/60000 [==============================] - 7s 115us/step - loss: 177.1915
Epoch 2/50
60000/60000 [==============================] - 6s 98us/step - loss: 165.5445
Epoch 3/50
60000/60000 [==============================] - 6s 97us/step - loss: 162.5204
Epoch 4/50
60000/60000 [==============================] - 6s 105us/step - loss: 160.3745
Epoch 5/50
60000/60000 [==============================] - 6s 95us/step - loss: 158.7253
Epoch 6/50
60000/60000 [==============================] - 5s 88us/step - loss: 157.3812
Epoch 7/50
60000/60000 [==============================] - 5s 87us/step - loss: 156.4264
Epoch 8/50
60000/60000 [==============================] - 5s 87us/step - loss: 155.6122
Epoch 9/50
60000/60000 [==============================] - 5s 87us/step - loss: 154.9788
Epoch 10/50
60000/60000 [==============================] - 8s 133us/step - loss: 154.4103
Epoch 11/50
60000/60000 [==============================] - 7s 124us/step - loss: 153.9329
Epoch 12/50
60000/60000 [==============================] - 7s 124us/step - loss: 153.5233
Epoch 13/50
60000/60000 [==============================] - 7s 119us/step - loss: 153.1616
Epoch 14/50
60000/60000 [==============================] - 5s 90us/step - loss: 152.8636
Epoch 15/50
60000/60000 [==============================] - 5s 87us/step - loss: 152.5635
Epoch 16/50
60000/60000 [==============================] - 5s 87us/step - loss: 152.3091
Epoch 17/50
60000/60000 [==============================] - 5s 87us/step - loss: 152.0695
Epoch 18/50
60000/60000 [==============================] - 5s 87us/step - loss: 151.8749
Epoch 19/50
60000/60000 [==============================] - 6s 96us/step - loss: 151.6810
Epoch 20/50
60000/60000 [==============================] - 6s 102us/step - loss: 151.4854
Epoch 21/50
60000/60000 [==============================] - 7s 109us/step - loss: 151.3427
Epoch 22/50
60000/60000 [==============================] - 7s 115us/step - loss: 151.1688
Epoch 23/50
60000/60000 [==============================] - 7s 122us/step - loss: 151.0231
Epoch 24/50
60000/60000 [==============================] - 9s 149us/step - loss: 150.8813
Epoch 25/50
60000/60000 [==============================] - 9s 148us/step - loss: 150.7379
Epoch 26/50
60000/60000 [==============================] - 8s 131us/step - loss: 150.6284
Epoch 27/50
60000/60000 [==============================] - 8s 132us/step - loss: 150.5099
Epoch 28/50
60000/60000 [==============================] - 8s 139us/step - loss: 150.3781
Epoch 29/50
60000/60000 [==============================] - 7s 123us/step - loss: 150.2626
Epoch 30/50
60000/60000 [==============================] - 9s 145us/step - loss: 150.1427
Epoch 31/50
60000/60000 [==============================] - 7s 125us/step - loss: 150.0507
Epoch 32/50
60000/60000 [==============================] - 7s 109us/step - loss: 149.9647
Epoch 33/50
60000/60000 [==============================] - 7s 109us/step - loss: 149.8467
Epoch 34/50
60000/60000 [==============================] - 6s 107us/step - loss: 149.7648
Epoch 35/50
60000/60000 [==============================] - 6s 108us/step - loss: 149.6857
Epoch 36/50
60000/60000 [==============================] - 6s 107us/step - loss: 149.5898
Epoch 37/50
60000/60000 [==============================] - 7s 109us/step - loss: 149.5215
Epoch 38/50
60000/60000 [==============================] - 7s 113us/step - loss: 149.4364
Epoch 39/50
60000/60000 [==============================] - 6s 108us/step - loss: 149.3712
Epoch 40/50
60000/60000 [==============================] - 6s 105us/step - loss: 149.2857
Epoch 41/50
60000/60000 [==============================] - 7s 110us/step - loss: 149.2206
Epoch 42/50
60000/60000 [==============================] - 6s 103us/step - loss: 149.1377
Epoch 43/50
60000/60000 [==============================] - 6s 105us/step - loss: 149.0938
Epoch 44/50
60000/60000 [==============================] - 6s 104us/step - loss: 149.0367
Epoch 45/50
60000/60000 [==============================] - 6s 107us/step - loss: 148.9448
Epoch 46/50
60000/60000 [==============================] - 7s 118us/step - loss: 148.9190
Epoch 47/50
60000/60000 [==============================] - 7s 113us/step - loss: 148.8529
Epoch 48/50
60000/60000 [==============================] - 7s 112us/step - loss: 148.7574
Epoch 49/50
60000/60000 [==============================] - 8s 134us/step - loss: 148.7258
Epoch 50/50
60000/60000 [==============================] - 7s 125us/step - loss: 148.6851
batch size 30 | MC sample size 25
Epoch 1/50
60000/60000 [==============================] - 16s 265us/step - loss: 177.4136
Epoch 2/50
60000/60000 [==============================] - 15s 246us/step - loss: 165.5013
Epoch 3/50
60000/60000 [==============================] - 15s 250us/step - loss: 162.5124
Epoch 4/50
60000/60000 [==============================] - 16s 258us/step - loss: 160.2793
Epoch 5/50
60000/60000 [==============================] - 15s 249us/step - loss: 158.4845
Epoch 6/50
60000/60000 [==============================] - 15s 251us/step - loss: 157.2233
Epoch 7/50
60000/60000 [==============================] - 14s 227us/step - loss: 156.2673
Epoch 8/50
60000/60000 [==============================] - 14s 238us/step - loss: 155.4734
Epoch 9/50
60000/60000 [==============================] - 16s 263us/step - loss: 154.8175
Epoch 10/50
60000/60000 [==============================] - 17s 287us/step - loss: 154.2241
Epoch 11/50
60000/60000 [==============================] - 16s 264us/step - loss: 153.7466
Epoch 12/50
60000/60000 [==============================] - 16s 266us/step - loss: 153.3271
Epoch 13/50
60000/60000 [==============================] - 15s 253us/step - loss: 152.9860
Epoch 14/50
60000/60000 [==============================] - 15s 248us/step - loss: 152.6864
Epoch 15/50
60000/60000 [==============================] - 16s 259us/step - loss: 152.3977
Epoch 16/50
60000/60000 [==============================] - 16s 265us/step - loss: 152.1612
Epoch 17/50
60000/60000 [==============================] - 16s 270us/step - loss: 151.9416
Epoch 18/50
60000/60000 [==============================] - 16s 263us/step - loss: 151.7231
Epoch 19/50
60000/60000 [==============================] - 16s 260us/step - loss: 151.5223
Epoch 20/50
60000/60000 [==============================] - 17s 281us/step - loss: 151.3646
Epoch 21/50
60000/60000 [==============================] - 16s 274us/step - loss: 151.2097
Epoch 22/50
60000/60000 [==============================] - 16s 274us/step - loss: 151.0664
Epoch 23/50
60000/60000 [==============================] - 16s 266us/step - loss: 150.9273
Epoch 24/50
60000/60000 [==============================] - 15s 249us/step - loss: 150.8110
Epoch 25/50
60000/60000 [==============================] - 16s 259us/step - loss: 150.6763
Epoch 26/50
60000/60000 [==============================] - 15s 245us/step - loss: 150.5467
Epoch 27/50
60000/60000 [==============================] - 14s 237us/step - loss: 150.4399
Epoch 28/50
60000/60000 [==============================] - 13s 223us/step - loss: 150.3244
Epoch 29/50
60000/60000 [==============================] - 13s 220us/step - loss: 150.2228
Epoch 30/50
60000/60000 [==============================] - 13s 220us/step - loss: 150.1280
Epoch 31/50
60000/60000 [==============================] - 13s 218us/step - loss: 150.0228
Epoch 32/50
60000/60000 [==============================] - 13s 214us/step - loss: 149.9021
Epoch 33/50
60000/60000 [==============================] - 13s 214us/step - loss: 149.8374
Epoch 34/50
60000/60000 [==============================] - 14s 241us/step - loss: 149.7349
Epoch 35/50
60000/60000 [==============================] - 14s 231us/step - loss: 149.6855
Epoch 36/50
60000/60000 [==============================] - 16s 267us/step - loss: 149.6180
Epoch 37/50
60000/60000 [==============================] - 16s 269us/step - loss: 149.5229
Epoch 38/50
60000/60000 [==============================] - 14s 240us/step - loss: 149.4495
Epoch 39/50
60000/60000 [==============================] - 16s 264us/step - loss: 149.3910
Epoch 40/50
60000/60000 [==============================] - 15s 245us/step - loss: 149.3203
Epoch 41/50
60000/60000 [==============================] - 14s 239us/step - loss: 149.2329
Epoch 42/50
60000/60000 [==============================] - 14s 237us/step - loss: 149.1760
Epoch 43/50
60000/60000 [==============================] - 16s 259us/step - loss: 149.1111
Epoch 44/50
60000/60000 [==============================] - 16s 274us/step - loss: 149.0437
Epoch 45/50
60000/60000 [==============================] - 17s 278us/step - loss: 148.9833
Epoch 46/50
60000/60000 [==============================] - 17s 278us/step - loss: 148.8811
Epoch 47/50
60000/60000 [==============================] - 17s 277us/step - loss: 148.8191
Epoch 48/50
60000/60000 [==============================] - 14s 240us/step - loss: 148.7496
Epoch 49/50
60000/60000 [==============================] - 15s 254us/step - loss: 148.6737
Epoch 50/50
60000/60000 [==============================] - 15s 257us/step - loss: 148.6208
In [84]:
# Draw the batch-size-30 runs (histories3) on one shared set of axes, calling
# plot_fit_history once per MC sample size.
# NOTE(review): plot_fit_history and golden_figsize are defined earlier in the
# notebook; presumably each call adds one loss curve -- confirm there.
fig, ax = plt.subplots(figsize=golden_figsize(6))

for mc_sample_size, history in zip(mc_sample_sizes, histories3):
    plot_fit_history(history, ax=ax, batch_size=30,
                     mc_sample_size=mc_sample_size)

plt.show()
In [ ]:
# Collect the per-batch-size groups of fit histories into one list. The order
# must line up with batch_sizes, since both are indexed by the same panel
# position when the comparison figure is drawn.
histories_all = [histories1, histories2, histories3]
In [68]:
# Batch size belonging to each entry of histories_all, in the same order
# (the comparison figure indexes both lists with the same panel position).
# NOTE(review): histories1 is assumed to be the batch-size-100 runs -- confirm
# against the cell that creates it.
batch_sizes = [100, 65, 30]
In [86]:
# Side-by-side comparison of every run: one panel per batch size, each panel
# receiving one plot_fit_history call per MC sample size, with a common
# y-range across panels so they can be compared directly. The figure is also
# written to disk as SVG.
assert len(histories_all) == len(batch_sizes), \
    'need exactly one batch size per group of histories'

n_panels = len(histories_all)
fig, axes = plt.subplots(ncols=n_panels, figsize=(4 * n_panels, 3))
# plt.subplots hands back a bare Axes (not an array) when ncols == 1.
axes = np.atleast_1d(axes)

# Running extrema of the loss over all runs; they define the shared y-range.
ymin = np.inf
ymax = -np.inf

for ax, panel_batch_size, panel_histories in zip(axes,
                                                 batch_sizes,
                                                 histories_all):

    for mc_sample_size, history in zip(mc_sample_sizes, panel_histories):

        # Index instead of .get(): a missing 'loss' entry should fail right
        # here with a KeyError rather than passing None on to np.min.
        loss = history.history['loss']

        ymin = min(ymin, np.min(loss))
        ymax = max(ymax, np.max(loss))

        plot_fit_history(history,
                         batch_size=panel_batch_size,
                         mc_sample_size=mc_sample_size,
                         ax=ax)

# Pad the common range by 1% on each side. Scaling the bounds like this only
# widens the range when the loss values are positive.
for ax in axes:
    ax.set_ylim(.99*ymin, 1.01*ymax)

# tight_layout sizes the margins from what is on the axes at call time, so it
# has to run after the curves are drawn and the limits are set, not before.
fig.tight_layout()

# Save this figure explicitly rather than whatever pyplot considers current.
fig.savefig('../../images/vae/nelbo_batch_vs_mc_sample_sizes.svg', format='svg')
plt.show()