variational_autoencoder_mc_samples_grid-checkpoint.ipynb (Source)
## Preamble
In [1]:
# Use the interactive nbAgg figure backend of the classic Jupyter Notebook.
# NOTE(review): this backend is not supported in JupyterLab and interactive
# figures may not render on static viewers (nbviewer/GitHub); consider
# `%matplotlib inline` (static) or `%matplotlib widget` (ipympl) instead.
%matplotlib notebook
In [2]:
import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm
from keras import backend as K
from keras.layers import (Input, InputLayer, Dense, Lambda, Layer,
Add, Multiply)
from keras.models import Model, Sequential
from keras.datasets import mnist
Using TensorFlow backend.
In [20]:
import pandas as pd
from functools import partial
from matplotlib.ticker import FormatStrFormatter
from keras.utils.vis_utils import model_to_dot, plot_model
from IPython.display import SVG
## Notebook Configuration
In [4]:
# Compact array display: two decimals, small floats in fixed-point rather
# than scientific notation, and summaries keeping 3 items at each edge.
np.set_printoptions(suppress=True, precision=2,
                    linewidth=80, edgeitems=3)
In [5]:
# Record the TensorFlow version the results below were produced with.
'TensorFlow version: {}'.format(K.tf.__version__)
Out[5]:
'TensorFlow version: 1.4.0'
## Constant definitions
In [6]:
# Monte Carlo sample sizes to compare (latent draws per input).
mc_sample_sizes = [1, 5, 25]
# Model sizes: MNIST images flattened to 784 pixels, a 2-D latent space,
# and a 256-unit hidden layer in both the encoder and the decoder.
original_dim, latent_dim, intermediate_dim = 784, 2, 256
# Training length and the std. dev. of the reparameterisation noise.
epochs = 50
epsilon_std = 1.0
In [7]:
# Load MNIST and flatten each image into a vector of `original_dim` pixel
# intensities scaled to [0, 1], as required by the Bernoulli likelihood in
# `nll`. The labels are not used by the training code in this section.
(x_train, y_train), (x_test, y_test) = mnist.load_data()
# Cast to Keras' float type (float32 by default) *before* scaling: dividing
# the uint8 images by 255. would otherwise yield float64, doubling the memory
# of the MC-tiled targets built in `fit_history` (N x 25 x 784 values for
# 25 samples) without any gain, since Keras trains in floatx anyway.
x_train = x_train.reshape(-1, original_dim).astype(K.floatx()) / 255.
x_test = x_test.reshape(-1, original_dim).astype(K.floatx()) / 255.
In [8]:
def nll(y_true, y_pred):
    """Bernoulli negative log-likelihood, summed over the last axis.

    ``K.binary_crossentropy`` yields one value per element; summing over the
    last (pixel) axis gives the per-sample reconstruction term, whereas
    ``keras.losses.binary_crossentropy`` would take the mean instead.
    """
    elementwise_bce = K.binary_crossentropy(y_true, y_pred)
    return K.sum(elementwise_bce, axis=-1)
In [9]:
class KLDivergenceLayer(Layer):
    """Pass-through layer that adds the Gaussian KL term to the model loss.

    Called on ``[mu, log_var]`` of a diagonal Gaussian posterior, it adds the
    batch mean of ``KL(N(mu, exp(log_var)) || N(0, I))`` through
    ``add_loss`` and returns its inputs unchanged.
    """

    def __init__(self, *args, **kwargs):
        # Flag kept from the original implementation.
        # NOTE(review): presumably consulted by Keras internals; confirm it
        # is still required for this Keras version.
        self.is_placeholder = True
        super(KLDivergenceLayer, self).__init__(*args, **kwargs)

    def call(self, inputs):
        mu, log_var = inputs
        # Closed-form KL divergence to a standard normal, summed over the
        # latent dimensions (last axis) -> one value per example.
        per_example_kl = -0.5 * K.sum(
            1 + log_var - K.square(mu) - K.exp(log_var), axis=-1)
        self.add_loss(K.mean(per_example_kl), inputs=inputs)
        return inputs
In [10]:
def build_vae(mc_sample_size, original_dim, latent_dim, intermediate_dim):
    """Assemble a VAE whose decoder is applied to ``mc_sample_size`` latent draws.

    The encoder maps each input to the mean and log-variance of a diagonal
    Gaussian, and ``KLDivergenceLayer`` attaches the KL term to the loss.
    Latent codes use the reparameterisation trick ``z = mu + sigma * eps``,
    where ``eps`` is a random-normal tensor of shape
    ``(batch, mc_sample_size, latent_dim)`` wired in as a second model input.
    The noise scale is the module-level ``epsilon_std``.

    Returns
    -------
    keras.models.Model
        Inputs ``[x, eps]``; output is the decoder's sigmoid means for every
        latent draw, which ``fit_history`` pairs with targets of shape
        ``(N, mc_sample_size, original_dim)``.
    """
    # Encoder: x -> (mu, log_var).
    x_in = Input(shape=(original_dim,))
    hidden = Dense(intermediate_dim, activation='relu')(x_in)
    mu = Dense(latent_dim)(hidden)
    log_var = Dense(latent_dim)(hidden)
    # Identity pass-through that registers the KL penalty via add_loss.
    mu, log_var = KLDivergenceLayer()([mu, log_var])
    sigma = Lambda(lambda t: K.exp(.5*t))(log_var)

    # Reparameterisation: the noise is generated inside the graph, so it is
    # never fed by the caller.
    noise_shape = (K.shape(x_in)[0], mc_sample_size, latent_dim)
    eps = Input(tensor=K.random_normal(shape=noise_shape, stddev=epsilon_std))
    # sigma and mu are 2-D while eps is 3-D; the merge layers broadcast them.
    scaled_eps = Multiply()([sigma, eps])
    z = Add()([mu, scaled_eps])

    # Decoder: z -> per-pixel Bernoulli means.
    decoder = Sequential()
    decoder.add(Dense(intermediate_dim, input_dim=latent_dim, activation='relu'))
    decoder.add(Dense(original_dim, activation='sigmoid'))

    return Model(inputs=[x_in, eps], outputs=decoder(z))
In [16]:
def fit_history(batch_size, mc_sample_size, validate=False):
    """Build a fresh VAE for ``mc_sample_size`` MC samples and train it.

    Parameters
    ----------
    batch_size : int
        Mini-batch size passed to ``Model.fit``.
    mc_sample_size : int
        Number of latent draws per input. The reconstruction target is
        repeated this many times along axis 1 so it matches the decoder
        output.
    validate : bool, optional
        If True, also evaluate on the test set after each epoch (Keras then
        records ``'val_loss'`` in the returned history). Defaults to False,
        which trains exactly as before.

    Returns
    -------
    keras.callbacks.History
        The object returned by ``Model.fit``.
    """
    print('batch size {} | MC sample size {}'
          .format(batch_size, mc_sample_size))

    def repeat_target(x):
        # (N, D) -> (N, mc_sample_size, D): one copy of each image per draw.
        return np.tile(np.expand_dims(x, axis=1),
                       reps=(1, mc_sample_size, 1))

    x_train_target = repeat_target(x_train)

    vae = build_vae(mc_sample_size, original_dim, latent_dim,
                    intermediate_dim)
    vae.compile(optimizer='rmsprop', loss=nll)

    # The tiled test target is large (10000 x mc_sample_size x 784), so it
    # is only materialised when it is actually used for validation.
    validation_data = None
    if validate:
        validation_data = (x_test, repeat_target(x_test))

    return vae.fit(x_train,
                   x_train_target,
                   shuffle=True,
                   epochs=epochs,
                   batch_size=batch_size,
                   validation_data=validation_data)
In [29]:
def plot_fit_history(h, batch_size, mc_sample_size, ax=None):
    """Plot the per-epoch training loss (negative ELBO) recorded by Keras.

    Parameters
    ----------
    h : keras.callbacks.History
        Object returned by ``Model.fit``; ``h.history['loss']`` is plotted.
    batch_size : int
        Shown in the axes title.
    mc_sample_size : int
        Shown in the legend label, so several runs can share one axes.
    ax : matplotlib.axes.Axes, optional
        Axes to draw on; defaults to the current axes.

    Returns
    -------
    matplotlib.axes.Axes
        The axes that was drawn on.
    """
    if ax is None:
        ax = plt.gca()
    history = pd.DataFrame(h.history)
    # The history frame is indexed from 0, but the first epoch is "Epoch 1"
    # in the training log; shift so the x-axis shows real epoch numbers.
    history.index = history.index + 1
    history.plot(ax=ax, y='loss',
                 label='MC samples: {:d}'.format(mc_sample_size))
    ax.set_title('Batch size: {:d}'.format(batch_size))
    ax.set_ylabel('NELBO')
    ax.set_xlabel('# epochs')
    return ax
In [18]:
def golden_figsize(width):
    """Return a ``(width, height)`` figure size in golden-ratio proportions.

    ``height = width / phi`` with ``phi = (1 + sqrt(5)) / 2``, so the figure
    is roughly 1.618 times wider than it is tall.
    """
    # A named function rather than an assigned lambda (PEP 8, E731) gives a
    # meaningful __name__ in tracebacks and room for this docstring.
    return (width, 2. * width / (1 + np.sqrt(5)))
In [25]:
# Train one VAE per MC sample size, using mini-batches of 100.
histories1 = [fit_history(100, size) for size in mc_sample_sizes]
batch size 100 | MC sample size 1 Epoch 1/50 60000/60000 [==============================] - 2s 40us/step - loss: 189.0387 Epoch 2/50 60000/60000 [==============================] - 2s 32us/step - loss: 169.0337 Epoch 3/50 60000/60000 [==============================] - 2s 31us/step - loss: 165.6830 Epoch 4/50 60000/60000 [==============================] - 2s 31us/step - loss: 163.4543 Epoch 5/50 60000/60000 [==============================] - 2s 31us/step - loss: 161.6128 Epoch 6/50 60000/60000 [==============================] - 2s 32us/step - loss: 160.1299 Epoch 7/50 60000/60000 [==============================] - 2s 32us/step - loss: 158.9801 Epoch 8/50 60000/60000 [==============================] - 2s 32us/step - loss: 158.1048 Epoch 9/50 60000/60000 [==============================] - 2s 32us/step - loss: 157.3637 Epoch 10/50 60000/60000 [==============================] - 2s 32us/step - loss: 156.7864 Epoch 11/50 60000/60000 [==============================] - 2s 32us/step - loss: 156.2620 Epoch 12/50 60000/60000 [==============================] - 2s 31us/step - loss: 155.7970 Epoch 13/50 60000/60000 [==============================] - 2s 32us/step - loss: 155.4058 Epoch 14/50 60000/60000 [==============================] - 2s 30us/step - loss: 155.0614 Epoch 15/50 60000/60000 [==============================] - 2s 31us/step - loss: 154.7134 Epoch 16/50 60000/60000 [==============================] - 2s 31us/step - loss: 154.4399 Epoch 17/50 60000/60000 [==============================] - 2s 31us/step - loss: 154.1532 Epoch 18/50 60000/60000 [==============================] - 2s 31us/step - loss: 153.9101 Epoch 19/50 60000/60000 [==============================] - 2s 31us/step - loss: 153.6731 Epoch 20/50 60000/60000 [==============================] - 2s 31us/step - loss: 153.4678 Epoch 21/50 60000/60000 [==============================] - 2s 31us/step - loss: 153.2793 Epoch 22/50 60000/60000 [==============================] - 2s 31us/step - loss: 153.0618 Epoch 23/50 
60000/60000 [==============================] - 2s 31us/step - loss: 152.8981 Epoch 24/50 60000/60000 [==============================] - 2s 31us/step - loss: 152.7158 Epoch 25/50 60000/60000 [==============================] - 2s 31us/step - loss: 152.5932 Epoch 26/50 60000/60000 [==============================] - 2s 31us/step - loss: 152.4110 Epoch 27/50 60000/60000 [==============================] - 2s 31us/step - loss: 152.2653 Epoch 28/50 60000/60000 [==============================] - 2s 31us/step - loss: 152.1090 Epoch 29/50 60000/60000 [==============================] - 2s 31us/step - loss: 151.9918 Epoch 30/50 60000/60000 [==============================] - 2s 31us/step - loss: 151.8823 Epoch 31/50 60000/60000 [==============================] - 2s 31us/step - loss: 151.7416 Epoch 32/50 60000/60000 [==============================] - 2s 31us/step - loss: 151.6229 Epoch 33/50 60000/60000 [==============================] - 2s 31us/step - loss: 151.5112 Epoch 34/50 60000/60000 [==============================] - 2s 31us/step - loss: 151.4135 Epoch 35/50 60000/60000 [==============================] - 2s 30us/step - loss: 151.3030 Epoch 36/50 60000/60000 [==============================] - 2s 30us/step - loss: 151.1725 Epoch 37/50 60000/60000 [==============================] - 2s 30us/step - loss: 151.0907 Epoch 38/50 60000/60000 [==============================] - 2s 30us/step - loss: 150.9281 Epoch 39/50 60000/60000 [==============================] - 2s 31us/step - loss: 150.8562 Epoch 40/50 60000/60000 [==============================] - 2s 31us/step - loss: 150.7535 Epoch 41/50 60000/60000 [==============================] - 2s 31us/step - loss: 150.6494 Epoch 42/50 60000/60000 [==============================] - 2s 31us/step - loss: 150.5700 Epoch 43/50 60000/60000 [==============================] - 2s 31us/step - loss: 150.4739 Epoch 44/50 60000/60000 [==============================] - 2s 32us/step - loss: 150.3871 Epoch 45/50 60000/60000 
[==============================] - 2s 31us/step - loss: 150.2817 Epoch 46/50 60000/60000 [==============================] - 2s 31us/step - loss: 150.1927 Epoch 47/50 60000/60000 [==============================] - 2s 30us/step - loss: 150.1212 Epoch 48/50 60000/60000 [==============================] - 2s 30us/step - loss: 149.9925 Epoch 49/50 60000/60000 [==============================] - 2s 30us/step - loss: 149.9569 Epoch 50/50 60000/60000 [==============================] - 2s 30us/step - loss: 149.8732 batch size 100 | MC sample size 5 Epoch 1/50 60000/60000 [==============================] - 4s 62us/step - loss: 190.1918 Epoch 2/50 60000/60000 [==============================] - 3s 53us/step - loss: 169.5446 Epoch 3/50 60000/60000 [==============================] - 3s 54us/step - loss: 166.2662 Epoch 4/50 60000/60000 [==============================] - 3s 53us/step - loss: 164.0870 Epoch 5/50 60000/60000 [==============================] - 3s 54us/step - loss: 162.5137 Epoch 6/50 60000/60000 [==============================] - 3s 53us/step - loss: 161.3143 Epoch 7/50 60000/60000 [==============================] - 3s 53us/step - loss: 160.2275 Epoch 8/50 60000/60000 [==============================] - 3s 53us/step - loss: 159.1967 Epoch 9/50 60000/60000 [==============================] - 3s 53us/step - loss: 158.1983 Epoch 10/50 60000/60000 [==============================] - 3s 53us/step - loss: 157.3319 Epoch 11/50 60000/60000 [==============================] - 3s 54us/step - loss: 156.5584 Epoch 12/50 60000/60000 [==============================] - 3s 54us/step - loss: 155.9154 Epoch 13/50 60000/60000 [==============================] - 3s 54us/step - loss: 155.3372 Epoch 14/50 60000/60000 [==============================] - 3s 54us/step - loss: 154.8555 Epoch 15/50 60000/60000 [==============================] - 3s 54us/step - loss: 154.4196 Epoch 16/50 60000/60000 [==============================] - 3s 53us/step - loss: 154.0410 Epoch 17/50 60000/60000 
[==============================] - 3s 55us/step - loss: 153.6830 Epoch 18/50 60000/60000 [==============================] - 3s 55us/step - loss: 153.3692 Epoch 19/50 60000/60000 [==============================] - 3s 55us/step - loss: 153.0748 Epoch 20/50 60000/60000 [==============================] - 3s 55us/step - loss: 152.8088 Epoch 21/50 60000/60000 [==============================] - 3s 54us/step - loss: 152.5538 Epoch 22/50 60000/60000 [==============================] - 3s 53us/step - loss: 152.2983 Epoch 23/50 60000/60000 [==============================] - 3s 53us/step - loss: 152.0915 Epoch 24/50 60000/60000 [==============================] - 3s 53us/step - loss: 151.8647 Epoch 25/50 60000/60000 [==============================] - 3s 54us/step - loss: 151.6694 Epoch 26/50 60000/60000 [==============================] - 3s 53us/step - loss: 151.4926 Epoch 27/50 60000/60000 [==============================] - 3s 53us/step - loss: 151.3275 Epoch 28/50 60000/60000 [==============================] - 3s 53us/step - loss: 151.1350 Epoch 29/50 60000/60000 [==============================] - 3s 53us/step - loss: 151.0158 Epoch 30/50 60000/60000 [==============================] - 3s 53us/step - loss: 150.8242 Epoch 31/50 60000/60000 [==============================] - 3s 53us/step - loss: 150.6936 Epoch 32/50 60000/60000 [==============================] - 3s 54us/step - loss: 150.5376 Epoch 33/50 60000/60000 [==============================] - 3s 54us/step - loss: 150.4095 Epoch 34/50 60000/60000 [==============================] - 3s 53us/step - loss: 150.2751 Epoch 35/50 60000/60000 [==============================] - 3s 53us/step - loss: 150.1579 Epoch 36/50 60000/60000 [==============================] - 3s 53us/step - loss: 150.0314 Epoch 37/50 60000/60000 [==============================] - 3s 54us/step - loss: 149.9030 Epoch 38/50 60000/60000 [==============================] - 3s 53us/step - loss: 149.7720 Epoch 39/50 60000/60000 [==============================] - 3s 
54us/step - loss: 149.6809 Epoch 40/50 60000/60000 [==============================] - 3s 53us/step - loss: 149.5872 Epoch 41/50 60000/60000 [==============================] - 3s 53us/step - loss: 149.4470 Epoch 42/50 60000/60000 [==============================] - 3s 53us/step - loss: 149.3811 Epoch 43/50 60000/60000 [==============================] - 3s 54us/step - loss: 149.2661 Epoch 44/50 60000/60000 [==============================] - 3s 53us/step - loss: 149.1528 Epoch 45/50 60000/60000 [==============================] - 3s 53us/step - loss: 149.0910 Epoch 46/50 60000/60000 [==============================] - 3s 53us/step - loss: 148.9832 Epoch 47/50 60000/60000 [==============================] - 3s 54us/step - loss: 148.8849 Epoch 48/50 60000/60000 [==============================] - 3s 55us/step - loss: 148.7902 Epoch 49/50 60000/60000 [==============================] - 3s 53us/step - loss: 148.6930 Epoch 50/50 60000/60000 [==============================] - 3s 53us/step - loss: 148.6199 batch size 100 | MC sample size 25 Epoch 1/50 60000/60000 [==============================] - 11s 176us/step - loss: 190.2488 Epoch 2/50 60000/60000 [==============================] - 10s 170us/step - loss: 169.3187 Epoch 3/50 60000/60000 [==============================] - 10s 169us/step - loss: 166.2625 Epoch 4/50 60000/60000 [==============================] - 10s 169us/step - loss: 164.1589 Epoch 5/50 60000/60000 [==============================] - 10s 170us/step - loss: 162.6340 Epoch 6/50 60000/60000 [==============================] - 10s 167us/step - loss: 161.3680 Epoch 7/50 60000/60000 [==============================] - 10s 169us/step - loss: 160.2168 Epoch 8/50 60000/60000 [==============================] - 10s 169us/step - loss: 159.1401 Epoch 9/50 60000/60000 [==============================] - 10s 173us/step - loss: 158.1189 Epoch 10/50 60000/60000 [==============================] - 10s 174us/step - loss: 157.2397 Epoch 11/50 60000/60000 [==============================] 
- 10s 168us/step - loss: 156.5077 Epoch 12/50 60000/60000 [==============================] - 10s 163us/step - loss: 155.8691 Epoch 13/50 60000/60000 [==============================] - 11s 181us/step - loss: 155.3454 Epoch 14/50 60000/60000 [==============================] - 10s 171us/step - loss: 154.8616 Epoch 15/50 60000/60000 [==============================] - 10s 165us/step - loss: 154.4592 Epoch 16/50 60000/60000 [==============================] - 10s 172us/step - loss: 154.0723 Epoch 17/50 60000/60000 [==============================] - 10s 174us/step - loss: 153.7450 Epoch 18/50 60000/60000 [==============================] - 10s 169us/step - loss: 153.4252 Epoch 19/50 60000/60000 [==============================] - 11s 175us/step - loss: 153.1287 Epoch 20/50 60000/60000 [==============================] - 10s 170us/step - loss: 152.8623 Epoch 21/50 60000/60000 [==============================] - 11s 181us/step - loss: 152.6080 Epoch 22/50 60000/60000 [==============================] - 11s 177us/step - loss: 152.3954 Epoch 23/50 60000/60000 [==============================] - 11s 179us/step - loss: 152.1744 Epoch 24/50 60000/60000 [==============================] - 10s 171us/step - loss: 151.9752 Epoch 25/50 60000/60000 [==============================] - 10s 172us/step - loss: 151.7443 Epoch 26/50 60000/60000 [==============================] - 11s 176us/step - loss: 151.5782 Epoch 27/50 60000/60000 [==============================] - 10s 169us/step - loss: 151.3990 Epoch 28/50 60000/60000 [==============================] - 10s 168us/step - loss: 151.2458 Epoch 29/50 60000/60000 [==============================] - 10s 168us/step - loss: 151.0896 Epoch 30/50 60000/60000 [==============================] - 10s 167us/step - loss: 150.9448 Epoch 31/50 60000/60000 [==============================] - 10s 172us/step - loss: 150.7903 Epoch 32/50 60000/60000 [==============================] - 10s 170us/step - loss: 150.6565 Epoch 33/50 60000/60000 
[==============================] - 11s 177us/step - loss: 150.5447 Epoch 34/50 60000/60000 [==============================] - 10s 166us/step - loss: 150.3798 Epoch 35/50 60000/60000 [==============================] - 10s 168us/step - loss: 150.2577 Epoch 36/50 60000/60000 [==============================] - 10s 174us/step - loss: 150.1473 Epoch 37/50 60000/60000 [==============================] - 10s 172us/step - loss: 150.0306 Epoch 38/50 60000/60000 [==============================] - 11s 180us/step - loss: 149.9003 Epoch 39/50 60000/60000 [==============================] - 10s 174us/step - loss: 149.7781 Epoch 40/50 60000/60000 [==============================] - 10s 169us/step - loss: 149.6790 Epoch 41/50 60000/60000 [==============================] - 10s 168us/step - loss: 149.5762 Epoch 42/50 60000/60000 [==============================] - 10s 167us/step - loss: 149.4664 Epoch 43/50 60000/60000 [==============================] - 10s 174us/step - loss: 149.3887 Epoch 44/50 60000/60000 [==============================] - 10s 171us/step - loss: 149.2608 Epoch 45/50 60000/60000 [==============================] - 10s 168us/step - loss: 149.1942 Epoch 46/50 60000/60000 [==============================] - 10s 170us/step - loss: 149.0778 Epoch 47/50 60000/60000 [==============================] - 10s 167us/step - loss: 148.9875 Epoch 48/50 60000/60000 [==============================] - 11s 176us/step - loss: 148.8957 Epoch 49/50 60000/60000 [==============================] - 10s 170us/step - loss: 148.8262 Epoch 50/50 60000/60000 [==============================] - 10s 170us/step - loss: 148.7420
In [30]:
# Overlay the NELBO curves of all batch-size-100 runs on a single axes.
fig, ax = plt.subplots(figsize=golden_figsize(6))
for mc_sample_size, history in zip(mc_sample_sizes, histories1):
    plot_fit_history(history, 100, mc_sample_size, ax=ax)
plt.show()
In [31]:
# Repeat the MC sample-size sweep with mini-batches of 65.
histories2 = [fit_history(65, size) for size in mc_sample_sizes]
batch size 65 | MC sample size 1 Epoch 1/50 60000/60000 [==============================] - 3s 55us/step - loss: 184.4563 Epoch 2/50 60000/60000 [==============================] - 3s 51us/step - loss: 168.0012 Epoch 3/50 60000/60000 [==============================] - 3s 44us/step - loss: 165.0855 Epoch 4/50 60000/60000 [==============================] - 3s 42us/step - loss: 163.2620 Epoch 5/50 60000/60000 [==============================] - 3s 45us/step - loss: 161.7270 Epoch 6/50 60000/60000 [==============================] - 3s 43us/step - loss: 160.3433 Epoch 7/50 60000/60000 [==============================] - 3s 47us/step - loss: 159.1166 Epoch 8/50 60000/60000 [==============================] - 3s 57us/step - loss: 158.0708 Epoch 9/50 60000/60000 [==============================] - 3s 42us/step - loss: 157.2468 Epoch 10/50 60000/60000 [==============================] - 2s 42us/step - loss: 156.5685 Epoch 11/50 60000/60000 [==============================] - 3s 42us/step - loss: 155.9980 Epoch 12/50 60000/60000 [==============================] - 3s 42us/step - loss: 155.4974 Epoch 13/50 60000/60000 [==============================] - 3s 42us/step - loss: 155.0259 Epoch 14/50 60000/60000 [==============================] - 3s 50us/step - loss: 154.6323 Epoch 15/50 60000/60000 [==============================] - 3s 48us/step - loss: 154.2812 Epoch 16/50 60000/60000 [==============================] - 3s 46us/step - loss: 153.9239 Epoch 17/50 60000/60000 [==============================] - 3s 46us/step - loss: 153.6179 Epoch 18/50 60000/60000 [==============================] - 3s 49us/step - loss: 153.3159 Epoch 19/50 60000/60000 [==============================] - 3s 47us/step - loss: 153.0734 Epoch 20/50 60000/60000 [==============================] - 3s 43us/step - loss: 152.8139 Epoch 21/50 60000/60000 [==============================] - 3s 42us/step - loss: 152.6091 Epoch 22/50 60000/60000 [==============================] - 2s 42us/step - loss: 152.4014 Epoch 23/50 
60000/60000 [==============================] - 2s 41us/step - loss: 152.2169 Epoch 24/50 60000/60000 [==============================] - 2s 40us/step - loss: 152.0270 Epoch 25/50 60000/60000 [==============================] - 3s 43us/step - loss: 151.8327 Epoch 26/50 60000/60000 [==============================] - 3s 43us/step - loss: 151.7105 Epoch 27/50 60000/60000 [==============================] - 3s 50us/step - loss: 151.5365 Epoch 28/50 60000/60000 [==============================] - 3s 42us/step - loss: 151.3909 Epoch 29/50 60000/60000 [==============================] - 3s 47us/step - loss: 151.2554 Epoch 30/50 60000/60000 [==============================] - 3s 44us/step - loss: 151.1200 Epoch 31/50 60000/60000 [==============================] - 3s 43us/step - loss: 150.9801 Epoch 32/50 60000/60000 [==============================] - 3s 50us/step - loss: 150.8682 Epoch 33/50 60000/60000 [==============================] - 3s 44us/step - loss: 150.7814 Epoch 34/50 60000/60000 [==============================] - 3s 46us/step - loss: 150.6491 Epoch 35/50 60000/60000 [==============================] - 3s 52us/step - loss: 150.5359 Epoch 36/50 60000/60000 [==============================] - 3s 57us/step - loss: 150.4560 Epoch 37/50 60000/60000 [==============================] - 3s 52us/step - loss: 150.3336 Epoch 38/50 60000/60000 [==============================] - 3s 49us/step - loss: 150.2309 Epoch 39/50 60000/60000 [==============================] - 3s 47us/step - loss: 150.1449 Epoch 40/50 60000/60000 [==============================] - 3s 46us/step - loss: 150.0229 Epoch 41/50 60000/60000 [==============================] - 3s 45us/step - loss: 149.9789 Epoch 42/50 60000/60000 [==============================] - 3s 45us/step - loss: 149.8242 Epoch 43/50 60000/60000 [==============================] - 3s 49us/step - loss: 149.7797 Epoch 44/50 60000/60000 [==============================] - 3s 52us/step - loss: 149.6702 Epoch 45/50 60000/60000 
[==============================] - 3s 47us/step - loss: 149.5554 Epoch 46/50 60000/60000 [==============================] - 3s 48us/step - loss: 149.4697 Epoch 47/50 60000/60000 [==============================] - 3s 56us/step - loss: 149.3671 Epoch 48/50 60000/60000 [==============================] - 3s 46us/step - loss: 149.2490 Epoch 49/50 60000/60000 [==============================] - 2s 39us/step - loss: 149.2043 Epoch 50/50 60000/60000 [==============================] - 2s 39us/step - loss: 149.1012 batch size 65 | MC sample size 5 Epoch 1/50 60000/60000 [==============================] - 4s 68us/step - loss: 184.9790 Epoch 2/50 60000/60000 [==============================] - 4s 59us/step - loss: 168.8429 Epoch 3/50 60000/60000 [==============================] - 4s 59us/step - loss: 164.4541 Epoch 4/50 60000/60000 [==============================] - 3s 58us/step - loss: 161.9596 Epoch 5/50 60000/60000 [==============================] - 4s 63us/step - loss: 160.1261 Epoch 6/50 60000/60000 [==============================] - 4s 62us/step - loss: 158.7714 Epoch 7/50 60000/60000 [==============================] - 4s 61us/step - loss: 157.7245 Epoch 8/50 60000/60000 [==============================] - 3s 57us/step - loss: 156.9012 Epoch 9/50 60000/60000 [==============================] - 3s 57us/step - loss: 156.2454 Epoch 10/50 60000/60000 [==============================] - 3s 57us/step - loss: 155.6829 Epoch 11/50 60000/60000 [==============================] - 3s 57us/step - loss: 155.1900 Epoch 12/50 60000/60000 [==============================] - 3s 58us/step - loss: 154.7334 Epoch 13/50 60000/60000 [==============================] - 3s 57us/step - loss: 154.3399 Epoch 14/50 60000/60000 [==============================] - 3s 57us/step - loss: 153.9612 Epoch 15/50 60000/60000 [==============================] - 3s 57us/step - loss: 153.6554 Epoch 16/50 60000/60000 [==============================] - 3s 57us/step - loss: 153.3446 Epoch 17/50 60000/60000 
[==============================] - 4s 62us/step - loss: 153.0813 Epoch 18/50 60000/60000 [==============================] - 4s 64us/step - loss: 152.7987 Epoch 19/50 60000/60000 [==============================] - 4s 68us/step - loss: 152.5490 Epoch 20/50 60000/60000 [==============================] - 4s 66us/step - loss: 152.3222 Epoch 21/50 60000/60000 [==============================] - 4s 64us/step - loss: 152.0873 Epoch 22/50 60000/60000 [==============================] - 4s 64us/step - loss: 151.8929 Epoch 23/50 60000/60000 [==============================] - 4s 64us/step - loss: 151.6653 Epoch 24/50 60000/60000 [==============================] - 4s 65us/step - loss: 151.4926 Epoch 25/50 60000/60000 [==============================] - 4s 64us/step - loss: 151.3432 Epoch 26/50 60000/60000 [==============================] - 4s 65us/step - loss: 151.1544 Epoch 27/50 60000/60000 [==============================] - 4s 66us/step - loss: 150.9863 Epoch 28/50 60000/60000 [==============================] - 5s 78us/step - loss: 150.8234 Epoch 29/50 60000/60000 [==============================] - 5s 76us/step - loss: 150.7162 Epoch 30/50 60000/60000 [==============================] - 5s 78us/step - loss: 150.5368 Epoch 31/50 60000/60000 [==============================] - 5s 81us/step - loss: 150.4092 Epoch 32/50 60000/60000 [==============================] - 5s 77us/step - loss: 150.2617 Epoch 33/50 60000/60000 [==============================] - 4s 74us/step - loss: 150.1359 Epoch 34/50 60000/60000 [==============================] - 4s 75us/step - loss: 150.0313 Epoch 35/50 60000/60000 [==============================] - 4s 73us/step - loss: 149.9199 Epoch 36/50 60000/60000 [==============================] - 4s 73us/step - loss: 149.8035 Epoch 37/50 60000/60000 [==============================] - 5s 78us/step - loss: 149.6750 Epoch 38/50 60000/60000 [==============================] - 5s 81us/step - loss: 149.5902 Epoch 39/50 60000/60000 [==============================] - 4s 
73us/step - loss: 149.4580 Epoch 40/50 60000/60000 [==============================] - 4s 68us/step - loss: 149.3713 Epoch 41/50 60000/60000 [==============================] - 4s 69us/step - loss: 149.2645 Epoch 42/50 60000/60000 [==============================] - 4s 64us/step - loss: 149.1763 Epoch 43/50 60000/60000 [==============================] - 4s 63us/step - loss: 149.0803 Epoch 44/50 60000/60000 [==============================] - 4s 63us/step - loss: 148.9728 Epoch 45/50 60000/60000 [==============================] - 4s 64us/step - loss: 148.8995 Epoch 46/50 60000/60000 [==============================] - 4s 63us/step - loss: 148.8191 Epoch 47/50 60000/60000 [==============================] - 4s 63us/step - loss: 148.7228 Epoch 48/50 60000/60000 [==============================] - 4s 65us/step - loss: 148.6509 Epoch 49/50 60000/60000 [==============================] - 4s 64us/step - loss: 148.5672 Epoch 50/50 60000/60000 [==============================] - 4s 64us/step - loss: 148.5031 batch size 65 | MC sample size 25 Epoch 1/50 60000/60000 [==============================] - 12s 196us/step - loss: 185.0923 Epoch 2/50 60000/60000 [==============================] - 11s 186us/step - loss: 167.0106 Epoch 3/50 60000/60000 [==============================] - 11s 187us/step - loss: 163.4489 Epoch 4/50 60000/60000 [==============================] - 11s 186us/step - loss: 161.7727 Epoch 5/50 60000/60000 [==============================] - 11s 186us/step - loss: 160.6129 Epoch 6/50 60000/60000 [==============================] - 11s 190us/step - loss: 159.6776 Epoch 7/50 60000/60000 [==============================] - 11s 179us/step - loss: 158.9042 Epoch 8/50 60000/60000 [==============================] - 11s 183us/step - loss: 158.2042 Epoch 9/50 60000/60000 [==============================] - 11s 190us/step - loss: 157.5665 Epoch 10/50 60000/60000 [==============================] - 11s 190us/step - loss: 156.9723 Epoch 11/50 60000/60000 [==============================] - 
11s 184us/step - loss: 156.5029 Epoch 12/50 60000/60000 [==============================] - 11s 180us/step - loss: 156.0532 Epoch 13/50 60000/60000 [==============================] - 11s 190us/step - loss: 155.6011 Epoch 14/50 60000/60000 [==============================] - 11s 185us/step - loss: 155.2406 Epoch 15/50 60000/60000 [==============================] - 11s 179us/step - loss: 154.9010 Epoch 16/50 60000/60000 [==============================] - 11s 182us/step - loss: 154.5755 Epoch 17/50 60000/60000 [==============================] - 11s 180us/step - loss: 154.3036 Epoch 18/50 60000/60000 [==============================] - 11s 187us/step - loss: 154.0090 Epoch 19/50 60000/60000 [==============================] - 11s 183us/step - loss: 153.7666 Epoch 20/50 60000/60000 [==============================] - 11s 179us/step - loss: 153.5260 Epoch 21/50 60000/60000 [==============================] - 11s 178us/step - loss: 153.3229 Epoch 22/50 60000/60000 [==============================] - 11s 182us/step - loss: 153.0734 Epoch 23/50 60000/60000 [==============================] - 11s 181us/step - loss: 152.8777 Epoch 24/50 60000/60000 [==============================] - 11s 184us/step - loss: 152.7168 Epoch 25/50 60000/60000 [==============================] - 11s 181us/step - loss: 152.5363 Epoch 26/50 60000/60000 [==============================] - 11s 181us/step - loss: 152.3606 Epoch 27/50 60000/60000 [==============================] - 11s 178us/step - loss: 152.1795 Epoch 28/50 60000/60000 [==============================] - 11s 181us/step - loss: 152.0166 Epoch 29/50 60000/60000 [==============================] - 11s 183us/step - loss: 151.8719 Epoch 30/50 60000/60000 [==============================] - 11s 183us/step - loss: 151.7429 Epoch 31/50 60000/60000 [==============================] - 11s 180us/step - loss: 151.6034 Epoch 32/50 60000/60000 [==============================] - 11s 180us/step - loss: 151.4665 Epoch 33/50 60000/60000 [==============================] 
- 11s 180us/step - loss: 151.3000 Epoch 34/50 60000/60000 [==============================] - 11s 181us/step - loss: 151.1969 Epoch 35/50 60000/60000 [==============================] - 11s 180us/step - loss: 151.0867 Epoch 36/50 60000/60000 [==============================] - 11s 180us/step - loss: 150.9416 Epoch 37/50 60000/60000 [==============================] - 11s 181us/step - loss: 150.8220 Epoch 38/50 60000/60000 [==============================] - 11s 185us/step - loss: 150.7098 Epoch 39/50 60000/60000 [==============================] - 11s 183us/step - loss: 150.5890 Epoch 40/50 60000/60000 [==============================] - 11s 180us/step - loss: 150.4597 Epoch 41/50 60000/60000 [==============================] - 11s 186us/step - loss: 150.3663 Epoch 42/50 60000/60000 [==============================] - 11s 182us/step - loss: 150.2396 Epoch 43/50 60000/60000 [==============================] - 11s 182us/step - loss: 150.1371 Epoch 44/50 60000/60000 [==============================] - 11s 189us/step - loss: 150.0284 Epoch 45/50 60000/60000 [==============================] - 12s 194us/step - loss: 149.9225 Epoch 46/50 60000/60000 [==============================] - 12s 195us/step - loss: 149.8169 Epoch 47/50 60000/60000 [==============================] - 12s 201us/step - loss: 149.7349 Epoch 48/50 60000/60000 [==============================] - 12s 202us/step - loss: 149.6022 Epoch 49/50 60000/60000 [==============================] - 12s 201us/step - loss: 149.5442 Epoch 50/50 60000/60000 [==============================] - 12s 200us/step - loss: 149.4351
In [35]:
# Overlay the NELBO curves of all batch-size-65 runs on a single axes.
fig, ax = plt.subplots(figsize=golden_figsize(6))
for mc_sample_size, history in zip(mc_sample_sizes, histories2):
    plot_fit_history(history, 65, mc_sample_size, ax=ax)
plt.show()
In [36]:
# Repeat the MC sample-size sweep with mini-batches of 30.
histories3 = [fit_history(30, size) for size in mc_sample_sizes]
batch size 30 | MC sample size 1 Epoch 1/50 60000/60000 [==============================] - 7s 124us/step - loss: 177.2299 Epoch 2/50 60000/60000 [==============================] - 6s 105us/step - loss: 165.9526 Epoch 3/50 60000/60000 [==============================] - 6s 101us/step - loss: 163.3389 Epoch 4/50 60000/60000 [==============================] - 6s 102us/step - loss: 161.5406 Epoch 5/50 60000/60000 [==============================] - 6s 103us/step - loss: 159.9758 Epoch 6/50 60000/60000 [==============================] - 7s 109us/step - loss: 158.7068 Epoch 7/50 60000/60000 [==============================] - 6s 100us/step - loss: 157.7793 Epoch 8/50 60000/60000 [==============================] - 6s 104us/step - loss: 156.9706 Epoch 9/50 60000/60000 [==============================] - 6s 99us/step - loss: 156.3414 Epoch 10/50 60000/60000 [==============================] - 5s 89us/step - loss: 155.7998 Epoch 11/50 60000/60000 [==============================] - 6s 93us/step - loss: 155.3102 Epoch 12/50 60000/60000 [==============================] - 5s 81us/step - loss: 154.8747 Epoch 13/50 60000/60000 [==============================] - 5s 79us/step - loss: 154.5256 Epoch 14/50 60000/60000 [==============================] - 5s 83us/step - loss: 154.1773 Epoch 15/50 60000/60000 [==============================] - 5s 89us/step - loss: 153.9464 Epoch 16/50 60000/60000 [==============================] - 5s 83us/step - loss: 153.6999 Epoch 17/50 60000/60000 [==============================] - 5s 85us/step - loss: 153.4632 Epoch 18/50 60000/60000 [==============================] - 5s 89us/step - loss: 153.2248 Epoch 19/50 60000/60000 [==============================] - 5s 85us/step - loss: 153.1149 Epoch 20/50 60000/60000 [==============================] - 5s 86us/step - loss: 152.9554 Epoch 21/50 60000/60000 [==============================] - 5s 86us/step - loss: 152.8291 Epoch 22/50 60000/60000 [==============================] - 5s 86us/step - loss: 152.6762 Epoch 
23/50 60000/60000 [==============================] - 5s 89us/step - loss: 152.5805 Epoch 24/50 60000/60000 [==============================] - 5s 90us/step - loss: 152.4390 Epoch 25/50 60000/60000 [==============================] - 5s 87us/step - loss: 152.3914 Epoch 26/50 60000/60000 [==============================] - 5s 82us/step - loss: 152.2789 Epoch 27/50 60000/60000 [==============================] - 5s 86us/step - loss: 152.2040 Epoch 28/50 60000/60000 [==============================] - 5s 82us/step - loss: 152.0841 Epoch 29/50 60000/60000 [==============================] - 5s 85us/step - loss: 152.0246 Epoch 30/50 60000/60000 [==============================] - 5s 84us/step - loss: 151.9551 Epoch 31/50 60000/60000 [==============================] - 5s 83us/step - loss: 151.9226 Epoch 32/50 60000/60000 [==============================] - 5s 84us/step - loss: 151.8533 Epoch 33/50 60000/60000 [==============================] - 5s 82us/step - loss: 151.8018 Epoch 34/50 60000/60000 [==============================] - 5s 82us/step - loss: 151.7092 Epoch 35/50 60000/60000 [==============================] - 5s 82us/step - loss: 151.6607 Epoch 36/50 60000/60000 [==============================] - 5s 82us/step - loss: 151.6110 Epoch 37/50 60000/60000 [==============================] - 5s 85us/step - loss: 151.5799 Epoch 38/50 60000/60000 [==============================] - 5s 86us/step - loss: 151.5609 Epoch 39/50 60000/60000 [==============================] - 6s 101us/step - loss: 151.5157 Epoch 40/50 60000/60000 [==============================] - 5s 90us/step - loss: 151.4409 Epoch 41/50 60000/60000 [==============================] - 5s 90us/step - loss: 151.4742 Epoch 42/50 60000/60000 [==============================] - 5s 89us/step - loss: 151.4231 Epoch 43/50 60000/60000 [==============================] - 5s 87us/step - loss: 151.3505 Epoch 44/50 60000/60000 [==============================] - 5s 91us/step - loss: 151.3848 Epoch 45/50 60000/60000 
[==============================] - 5s 84us/step - loss: 151.3290 Epoch 46/50 60000/60000 [==============================] - 6s 92us/step - loss: 151.3031 Epoch 47/50 60000/60000 [==============================] - 5s 83us/step - loss: 151.2973 Epoch 48/50 60000/60000 [==============================] - 5s 80us/step - loss: 151.2716 Epoch 49/50 60000/60000 [==============================] - 5s 82us/step - loss: 151.1704 Epoch 50/50 60000/60000 [==============================] - 5s 85us/step - loss: 151.1809 batch size 30 | MC sample size 5 Epoch 1/50 60000/60000 [==============================] - 7s 115us/step - loss: 177.1915 Epoch 2/50 60000/60000 [==============================] - 6s 98us/step - loss: 165.5445 Epoch 3/50 60000/60000 [==============================] - 6s 97us/step - loss: 162.5204 Epoch 4/50 60000/60000 [==============================] - 6s 105us/step - loss: 160.3745 Epoch 5/50 60000/60000 [==============================] - 6s 95us/step - loss: 158.7253 Epoch 6/50 60000/60000 [==============================] - 5s 88us/step - loss: 157.3812 Epoch 7/50 60000/60000 [==============================] - 5s 87us/step - loss: 156.4264 Epoch 8/50 60000/60000 [==============================] - 5s 87us/step - loss: 155.6122 Epoch 9/50 60000/60000 [==============================] - 5s 87us/step - loss: 154.9788 Epoch 10/50 60000/60000 [==============================] - 8s 133us/step - loss: 154.4103 Epoch 11/50 60000/60000 [==============================] - 7s 124us/step - loss: 153.9329 Epoch 12/50 60000/60000 [==============================] - 7s 124us/step - loss: 153.5233 Epoch 13/50 60000/60000 [==============================] - 7s 119us/step - loss: 153.1616 Epoch 14/50 60000/60000 [==============================] - 5s 90us/step - loss: 152.8636 Epoch 15/50 60000/60000 [==============================] - 5s 87us/step - loss: 152.5635 Epoch 16/50 60000/60000 [==============================] - 5s 87us/step - loss: 152.3091 Epoch 17/50 60000/60000 
[==============================] - 5s 87us/step - loss: 152.0695 Epoch 18/50 60000/60000 [==============================] - 5s 87us/step - loss: 151.8749 Epoch 19/50 60000/60000 [==============================] - 6s 96us/step - loss: 151.6810 Epoch 20/50 60000/60000 [==============================] - 6s 102us/step - loss: 151.4854 Epoch 21/50 60000/60000 [==============================] - 7s 109us/step - loss: 151.3427 Epoch 22/50 60000/60000 [==============================] - 7s 115us/step - loss: 151.1688 Epoch 23/50 60000/60000 [==============================] - 7s 122us/step - loss: 151.0231 Epoch 24/50 60000/60000 [==============================] - 9s 149us/step - loss: 150.8813 Epoch 25/50 60000/60000 [==============================] - 9s 148us/step - loss: 150.7379 Epoch 26/50 60000/60000 [==============================] - 8s 131us/step - loss: 150.6284 Epoch 27/50 60000/60000 [==============================] - 8s 132us/step - loss: 150.5099 Epoch 28/50 60000/60000 [==============================] - 8s 139us/step - loss: 150.3781 Epoch 29/50 60000/60000 [==============================] - 7s 123us/step - loss: 150.2626 Epoch 30/50 60000/60000 [==============================] - 9s 145us/step - loss: 150.1427 Epoch 31/50 60000/60000 [==============================] - 7s 125us/step - loss: 150.0507 Epoch 32/50 60000/60000 [==============================] - 7s 109us/step - loss: 149.9647 Epoch 33/50 60000/60000 [==============================] - 7s 109us/step - loss: 149.8467 Epoch 34/50 60000/60000 [==============================] - 6s 107us/step - loss: 149.7648 Epoch 35/50 60000/60000 [==============================] - 6s 108us/step - loss: 149.6857 Epoch 36/50 60000/60000 [==============================] - 6s 107us/step - loss: 149.5898 Epoch 37/50 60000/60000 [==============================] - 7s 109us/step - loss: 149.5215 Epoch 38/50 60000/60000 [==============================] - 7s 113us/step - loss: 149.4364 Epoch 39/50 60000/60000 
[==============================] - 6s 108us/step - loss: 149.3712 Epoch 40/50 60000/60000 [==============================] - 6s 105us/step - loss: 149.2857 Epoch 41/50 60000/60000 [==============================] - 7s 110us/step - loss: 149.2206 Epoch 42/50 60000/60000 [==============================] - 6s 103us/step - loss: 149.1377 Epoch 43/50 60000/60000 [==============================] - 6s 105us/step - loss: 149.0938 Epoch 44/50 60000/60000 [==============================] - 6s 104us/step - loss: 149.0367 Epoch 45/50 60000/60000 [==============================] - 6s 107us/step - loss: 148.9448 Epoch 46/50 60000/60000 [==============================] - 7s 118us/step - loss: 148.9190 Epoch 47/50 60000/60000 [==============================] - 7s 113us/step - loss: 148.8529 Epoch 48/50 60000/60000 [==============================] - 7s 112us/step - loss: 148.7574 Epoch 49/50 60000/60000 [==============================] - 8s 134us/step - loss: 148.7258 Epoch 50/50 60000/60000 [==============================] - 7s 125us/step - loss: 148.6851 batch size 30 | MC sample size 25 Epoch 1/50 60000/60000 [==============================] - 16s 265us/step - loss: 177.4136 Epoch 2/50 60000/60000 [==============================] - 15s 246us/step - loss: 165.5013 Epoch 3/50 60000/60000 [==============================] - 15s 250us/step - loss: 162.5124 Epoch 4/50 60000/60000 [==============================] - 16s 258us/step - loss: 160.2793 Epoch 5/50 60000/60000 [==============================] - 15s 249us/step - loss: 158.4845 Epoch 6/50 60000/60000 [==============================] - 15s 251us/step - loss: 157.2233 Epoch 7/50 60000/60000 [==============================] - 14s 227us/step - loss: 156.2673 Epoch 8/50 60000/60000 [==============================] - 14s 238us/step - loss: 155.4734 Epoch 9/50 60000/60000 [==============================] - 16s 263us/step - loss: 154.8175 Epoch 10/50 60000/60000 [==============================] - 17s 287us/step - loss: 154.2241 Epoch 
11/50 60000/60000 [==============================] - 16s 264us/step - loss: 153.7466 Epoch 12/50 60000/60000 [==============================] - 16s 266us/step - loss: 153.3271 Epoch 13/50 60000/60000 [==============================] - 15s 253us/step - loss: 152.9860 Epoch 14/50 60000/60000 [==============================] - 15s 248us/step - loss: 152.6864 Epoch 15/50 60000/60000 [==============================] - 16s 259us/step - loss: 152.3977 Epoch 16/50 60000/60000 [==============================] - 16s 265us/step - loss: 152.1612 Epoch 17/50 60000/60000 [==============================] - 16s 270us/step - loss: 151.9416 Epoch 18/50 60000/60000 [==============================] - 16s 263us/step - loss: 151.7231 Epoch 19/50 60000/60000 [==============================] - 16s 260us/step - loss: 151.5223 Epoch 20/50 60000/60000 [==============================] - 17s 281us/step - loss: 151.3646 Epoch 21/50 60000/60000 [==============================] - 16s 274us/step - loss: 151.2097 Epoch 22/50 60000/60000 [==============================] - 16s 274us/step - loss: 151.0664 Epoch 23/50 60000/60000 [==============================] - 16s 266us/step - loss: 150.9273 Epoch 24/50 60000/60000 [==============================] - 15s 249us/step - loss: 150.8110 Epoch 25/50 60000/60000 [==============================] - 16s 259us/step - loss: 150.6763 Epoch 26/50 60000/60000 [==============================] - 15s 245us/step - loss: 150.5467 Epoch 27/50 60000/60000 [==============================] - 14s 237us/step - loss: 150.4399 Epoch 28/50 60000/60000 [==============================] - 13s 223us/step - loss: 150.3244 Epoch 29/50 60000/60000 [==============================] - 13s 220us/step - loss: 150.2228 Epoch 30/50 60000/60000 [==============================] - 13s 220us/step - loss: 150.1280 Epoch 31/50 60000/60000 [==============================] - 13s 218us/step - loss: 150.0228 Epoch 32/50 60000/60000 [==============================] - 13s 214us/step - loss: 149.9021 
Epoch 33/50 60000/60000 [==============================] - 13s 214us/step - loss: 149.8374 Epoch 34/50 60000/60000 [==============================] - 14s 241us/step - loss: 149.7349 Epoch 35/50 60000/60000 [==============================] - 14s 231us/step - loss: 149.6855 Epoch 36/50 60000/60000 [==============================] - 16s 267us/step - loss: 149.6180 Epoch 37/50 60000/60000 [==============================] - 16s 269us/step - loss: 149.5229 Epoch 38/50 60000/60000 [==============================] - 14s 240us/step - loss: 149.4495 Epoch 39/50 60000/60000 [==============================] - 16s 264us/step - loss: 149.3910 Epoch 40/50 60000/60000 [==============================] - 15s 245us/step - loss: 149.3203 Epoch 41/50 60000/60000 [==============================] - 14s 239us/step - loss: 149.2329 Epoch 42/50 60000/60000 [==============================] - 14s 237us/step - loss: 149.1760 Epoch 43/50 60000/60000 [==============================] - 16s 259us/step - loss: 149.1111 Epoch 44/50 60000/60000 [==============================] - 16s 274us/step - loss: 149.0437 Epoch 45/50 60000/60000 [==============================] - 17s 278us/step - loss: 148.9833 Epoch 46/50 60000/60000 [==============================] - 17s 278us/step - loss: 148.8811 Epoch 47/50 60000/60000 [==============================] - 17s 277us/step - loss: 148.8191 Epoch 48/50 60000/60000 [==============================] - 14s 240us/step - loss: 148.7496 Epoch 49/50 60000/60000 [==============================] - 15s 254us/step - loss: 148.6737 Epoch 50/50 60000/60000 [==============================] - 15s 257us/step - loss: 148.6208
In [84]:
# Overlay the fit histories of the batch-size-30 runs on a single axes,
# one plot_fit_history call per MC sample size.
fig, ax = plt.subplots(figsize=golden_figsize(6))
for mc_sample_size, history in zip(mc_sample_sizes, histories3):
    plot_fit_history(history, batch_size=30,
                     mc_sample_size=mc_sample_size, ax=ax)
plt.show()
In [ ]:
# Group the three training sweeps. Position i is later paired with
# batch_sizes[i] when plotting, so this order must match batch_sizes.
# histories3 was fit with batch size 30; histories1/histories2 presumably
# with 100 and 65 (defined earlier in the notebook; verify there).
histories_all = [histories1, histories2, histories3]
In [68]:
# Batch size of each sweep in histories_all (same order, index for index).
batch_sizes = [
    100,
    65,
    30,
]
In [86]:
# Compare training-loss curves across batch sizes (one panel per entry of
# batch_sizes) and MC sample sizes (one plot_fit_history call per entry of
# mc_sample_sizes within each panel). All panels share one y-range so the
# curves are directly comparable, and the figure is saved as SVG.

# zip() below would silently drop panels if these ever fell out of step.
assert len(histories_all) == len(batch_sizes), (
    'histories_all and batch_sizes must be aligned one-to-one')

# squeeze=False guarantees a 2-D array even for a single panel; keep the row.
fig, axes = plt.subplots(ncols=len(batch_sizes), figsize=(12, 3),
                         squeeze=False)
axes = axes[0]

# Gather every recorded training loss while plotting, to derive a shared
# y-range afterwards.
all_losses = []
for ax, batch_size, run_histories in zip(axes, batch_sizes, histories_all):
    for mc_sample_size, history in zip(mc_sample_sizes, run_histories):
        all_losses.extend(history.history['loss'])
        plot_fit_history(history,
                         batch_size=batch_size,
                         mc_sample_size=mc_sample_size,
                         ax=ax)

ymin = min(all_losses)
ymax = max(all_losses)
for ax in axes:
    ax.set_ylim(.99 * ymin, 1.01 * ymax)

# tight_layout is a one-shot computation: it must run after all titles,
# labels and tick labels exist, otherwise they are not accounted for.
fig.tight_layout()
fig.savefig('../../images/vae/nelbo_batch_vs_mc_sample_sizes.svg', format='svg')
plt.show()