Notes (old posts, page 1)

Variational Inference with Implicit Approximate Inference Models (WIP)

In [69]:
# Render figures inline in the notebook, as SVG (vector) output.
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [70]:
import matplotlib.pyplot as plt
import seaborn as sns

import numpy as np
import keras.backend as K

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit
from mpl_toolkits.mplot3d import Axes3D

from IPython.display import SVG
In [71]:
# Plot styling. matplotlib 3.6 renamed its bundled seaborn styles to
# 'seaborn-v0_8-*' and 3.8 removed the old names (style.use then raises
# OSError), so fall back to the new name on recent matplotlib.
try:
    plt.style.use('seaborn-notebook')
except OSError:
    plt.style.use('seaborn-v0_8-notebook')
sns.set_context('notebook')
In [74]:
# Compact NumPy printing: 2 decimals, fixed-point notation (no scientific
# format), 80-char lines, and 3 items at each edge of summarised arrays.
np.set_printoptions(linewidth=80, precision=2, suppress=True, edgeitems=3)
In [75]:
# Display the TensorFlow version behind the Keras backend.
# NOTE(review): `K.tf` is presumably the tensorflow module re-exported by the
# Keras TensorFlow backend, so this likely only works with that backend (and
# older standalone Keras) -- confirm before relying on it.
K.tf.__version__
Out[75]:
'1.2.1'
In [76]:
# --- Dimensions -----------------------------------------------------------
# Latent dimension; in this notebook the latent variable is the weight vector w.
LATENT_DIM = 2
# NOTE(review): not used in the cells shown here -- presumably the size of the
# noise input to the implicit inference model; confirm.
NOISE_DIM = 3

# --- Batch sizes ----------------------------------------------------------
# NOTE(review): presumably generic / discriminator (D_) / generator (G_) batch
# sizes for later training cells; confirm.
BATCH_SIZE = D_BATCH_SIZE = G_BATCH_SIZE = 128

# --- Prior ----------------------------------------------------------------
# Variance of the isotropic Gaussian prior over w.
PRIOR_VARIANCE = 2.
In [77]:
# Plotting window, shared by both weight coordinates w1 and w2.
w_min = -5
w_max = 5
In [78]:
# 300 x 300 evaluation grid over [w_min, w_max]^2; the complex step 300j
# means "300 points, both endpoints included".
w1, w2 = np.mgrid[(slice(w_min, w_max, 300j),) * 2]
In [79]:
# Pack the coordinate grids into one array of 2-D points:
# w_grid[i, j] == (w1[i, j], w2[i, j]).
w_grid = np.stack([w1, w2], axis=-1)
w_grid.shape
Out[79]:
(300, 300, 2)
In [80]:
# Isotropic Gaussian prior over the weights: N(0, PRIOR_VARIANCE * I).
prior = multivariate_normal(mean=np.zeros(LATENT_DIM),
                            cov=PRIOR_VARIANCE * np.eye(LATENT_DIM))
In [81]:
# Prior log-density at every grid point; the last axis of w_grid holds the
# coordinates, so one value per grid cell remains.
log_prior = prior.logpdf(x=w_grid)
np.shape(log_prior)
Out[81]:
(300, 300)
In [82]:
fig, ax = plt.subplots(figsize=(5, 5))

# Filled contours of the prior log-density over the weight grid.
ax.contourf(w1, w2, log_prior, cmap='magma')
ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))

plt.show()
In [83]:
# The three 2-D input points of the toy logistic-regression problem.
x1, x2, x3 = (np.array(point) for point in ([1.5, 1.], [-1.5, 1.], [-.5, -1.]))
In [84]:
# Design matrix: one row per observation.
X = np.array([x1, x2, x3])
X.shape
Out[84]:
(3, 2)
In [85]:
# Binary labels for x1, x2 and x3 respectively.
y1, y2, y3 = 1, 1, 0
In [86]:
# Label vector aligned with the rows of X.
y = np.array([y1, y2, y3])
y.shape
Out[86]:
(3,)
In [87]:
def log_likelihood(w, x, y):
    """Logistic-regression log-likelihood log p(y | x, w).

    Equivalent to the negative binary cross-entropy between the labels and
    sigmoid(w^T x).

    Parameters
    ----------
    w : array_like
        Weights; ``np.dot(w.T, x)`` must be defined. The notebook passes
        ``w_grid.T`` (shape ``(2, 300, 300)``) so that ``w.T`` is the
        ``(300, 300, 2)`` grid of weight vectors.
    x : array_like
        Inputs with the feature axis first, e.g. ``X.T`` of shape ``(2, 3)``.
    y : array_like or int
        Binary labels in {0, 1}, broadcastable against ``np.dot(w.T, x)``.

    Returns
    -------
    ndarray
        ``log sigmoid(s * w^T x)`` with ``s = +1`` where ``y == 1`` and
        ``s = -1`` where ``y == 0``.
    """
    # Map labels {0, 1} -> signs {-1, +1}. This equals (-1)**(1-y) for binary
    # y, but avoids integer powers and also accepts boolean labels.
    signed_logits = np.dot(w.T, x) * (2 * y - 1)
    # log(sigmoid(z)) = -log(1 + exp(-z)) = -logaddexp(0, -z). This stays
    # finite for large negative z, where log(expit(z)) hits log(0) = -inf.
    return -np.logaddexp(0., -signed_logits)
In [88]:
# Per-observation log-likelihoods on the grid. Transposing w_grid to
# (2, 300, 300) makes w.T inside log_likelihood the original (300, 300, 2)
# grid, so llhs is indexed [grid row, grid column, observation].
llhs = log_likelihood(w=w_grid.T, x=X.T, y=y)
np.shape(llhs)
Out[88]:
(300, 300, 3)
In [89]:
# One panel per observation n, showing log p(y_n | x_n, w) over the weight grid.
fig, axes = plt.subplots(ncols=llhs.shape[-1], nrows=1,
                         figsize=(2 * llhs.shape[-1], 2))

for i, ax in enumerate(axes):

    ax.contourf(w1, w2, llhs[..., i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)

    # Raw string: '\m' is not a valid Python escape sequence. The panels show
    # log-likelihoods, so the title reads log p(...).
    ax.set_title(r'$\log p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')

    if not i:
        ax.set_ylabel('$w_2$')

# Lay out only after titles and labels exist, so they are taken into account.
fig.tight_layout()
plt.show()
In [90]:
fig, ax = plt.subplots(figsize=(6, 5))

# Joint log-likelihood: sum over the observation axis of llhs.
c = ax.contourf(w1, w2, llhs.sum(axis=2), cmap=plt.cm.magma)

ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))

plt.colorbar(c)
plt.show()
In [91]:
# Unnormalised posterior p(w) * prod_n p(y_n | x_n, w) on the weight grid.
fig, ax = plt.subplots(figsize=(6, 5))

c = ax.contourf(w1, w2,
                np.exp(log_prior + np.sum(llhs, axis=2)),
                cmap=plt.cm.magma)
# Overlay the input points (the rows of X). They are drawn in the same 2-D
# coordinates as w; reuse X rather than re-stacking x1, x2, x3.
ax.plot(*X.T, 'ro')

ax.set_title(r'$p(w) \prod_n p(y_n \mid x_n, w)$ (unnormalised)')
ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.colorbar(c)
plt.show()

Model Definitions

Density Ratio Estimator (Discriminator) Model

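In general, the density ratio estimator $T_{\psi}(x, z)$ takes both an observation $x$ and a latent variable $z$.

Here we consider the simpler estimator $T_{\psi}(w)$, a function of the weights $w$ alone:

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$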

In [92]:
# Discriminator: a small ReLU MLP on w. The sigmoid output is what is trained
# with binary cross-entropy; the pre-sigmoid layer is named 'logit' so that it
# can be extracted separately.
discriminator = Sequential(layers=[
    Dense(10, input_dim=LATENT_DIM, activation='relu'),
    Dense(20, activation='relu'),
    Dense(1, activation=None, name='logit'),
    Activation('sigmoid'),
], name='discriminator')
discriminator.compile(loss='binary_crossentropy',
                      optimizer='adam')
In [93]:
# T_psi(w): the discriminator's network truncated at its pre-sigmoid 'logit'
# layer (shares the discriminator's weights).
ratio_estimator = Model(discriminator.inputs,
                        discriminator.get_layer('logit').output)
In [94]:
# Render the discriminator architecture with Graphviz and display it inline.
discriminator_graph = model_to_dot(discriminator, show_shapes=True)
SVG(discriminator_graph.create(prog='dot', format='svg'))
Out[94]:
G 140717329860592 dense_6_input: InputLayerinput:output:(None, 2)(None, 2)140717330379888 dense_6: Denseinput:output:(None, 2)(None, 10)140717329860592->140717330379888 140717330379944 dense_7: Denseinput:output:(None, 10)(None, 20)140717330379888->140717330379944 140717329859024 logit: Denseinput:output:(None, 20)(None, 1)140717330379944->140717329859024 140717329765656 activation_2: Activationinput:output:(None, 1)(None, 1)140717329859024->140717329765656
In [95]:
# Evaluate the estimator's logit output at every grid point. Flatten the grid
# to (n_points, LATENT_DIM) for predict(), then restore the grid's shape.
# Both shapes derive from w_grid, so changing the grid resolution needs no edit.
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(-1, LATENT_DIM))
w_grid_ratio = w_grid_ratio.reshape(w_grid.shape[:-1])

Initial density ratio estimate (the discriminator's logit output), before any training:

In [96]:
fig, ax = plt.subplots(figsize=(5, 5))

# Untrained estimator output over the weight grid.
ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)
ax.set(xlabel='$w_1$', ylabel='$w_2$',
       xlim=(w_min, w_max), ylim=(w_min, w_max))

plt.show()