Variational Inference with Implicit Approximate Inference Models (WIP Pt. 2)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [2]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Dense, Dot, Input
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import SVG
Using TensorFlow backend.
In [3]:
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
# display animation inline
plt.rc('animation', html='html5')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 128
D_BATCH_SIZE = 128
G_BATCH_SIZE = 128
PRIOR_VARIANCE = 2.

Bayesian Logistic Regression (Synthetic Data)

We consider Bayesian logistic regression on a synthetic dataset of three points, with likelihood $p(y_n \mid x_n, w) = \mathrm{Bernoulli}(\sigma(w^T x_n))$ and prior $w \sim N(0, 2 I_2)$. Since $w$ is two-dimensional, we can visualize the prior, the likelihood, and the (unnormalized) posterior directly on a grid.

In [86]:
w_min, w_max = -5, 5
In [87]:
w1, w2 = np.mgrid[w_min:w_max:300j, w_min:w_max:300j]
In [88]:
w_grid = np.dstack((w1, w2))
w_grid.shape
Out[88]:
(300, 300, 2)
In [89]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [90]:
log_prior = prior.logpdf(w_grid)
log_prior.shape
Out[90]:
(300, 300)
In [91]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, log_prior, cmap='magma')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [92]:
x1 = np.array([ 1.5,  1.])
x2 = np.array([-1.5,  1.])
x3 = np.array([- .5, -1.])
In [93]:
X = np.vstack((x1, x2, x3))
X.shape
Out[93]:
(3, 2)
In [94]:
y1 = 1
y2 = 1
y3 = 0
In [95]:
y = np.stack((y1, y2, y3))
y.shape
Out[95]:
(3,)
In [96]:
def log_likelihood(w, x, y):
    # log p(y | x, w) = log sigmoid((2y - 1) w'x); the factor (-1)**(1-y)
    # flips the sign of the logit when y = 0, using sigmoid(-a) = 1 - sigmoid(a).
    # Equivalent to the negative binary cross-entropy.
    return np.log(expit(np.dot(w.T, x)*(-1)**(1-y)))
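
As a quick sanity check (a small verification sketch, not part of the original notebook; `w_test` is an arbitrary test point), this agrees with the Bernoulli log-pmf from `scipy.stats`:

from scipy.stats import bernoulli

w_test = np.random.randn(LATENT_DIM)
# log sigmoid((2y - 1) w'x) should match the Bernoulli log-likelihood
np.testing.assert_allclose(
    log_likelihood(w_test, X.T, y),
    bernoulli(expit(np.dot(X, w_test))).logpmf(y),
    rtol=1e-6)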
In [97]:
llhs = log_likelihood(w_grid.T, X.T, y)
llhs.shape
Out[97]:
(300, 300, 3)
In [98]:
fig, axes = plt.subplots(ncols=3, nrows=1, figsize=(6, 2))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(w1, w2, llhs[..., i], cmap=plt.cm.magma)

    ax.set_xlim(w_min, w_max)
    ax.set_ylim(w_min, w_max)
    
    ax.set_title(r'$p(y_{{{0}}} \mid x_{{{0}}}, w)$'.format(i+1))
    ax.set_xlabel('$w_1$')    
    
    if not i:
        ax.set_ylabel('$w_2$')

plt.show()
In [104]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, np.sum(llhs, axis=2), 
                cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [105]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.plot(*np.vstack((x1,x2,x3)).T, 'ro')

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
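
The mode of this unnormalized posterior is the MAP estimate of $w$. As a rough check (a sketch using the grid already computed; `log_post` and `w_map` are names of my choosing), it can be read straight off the grid:

log_post = log_prior + np.sum(llhs, axis=2)
# grid cell with the highest unnormalized log posterior
i, j = np.unravel_index(log_post.argmax(), log_post.shape)
w_map = w_grid[i, j]  # approximate MAP estimate on the 300x300 grid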

Model Definitions

Density Ratio Estimator (Discriminator) Model

In general, the density ratio estimator is a function $T_{\psi}(x, z)$ of both the data and the latent variables. Since we perform inference over the global weights $w$ alone, here we consider

$T_{\psi}(w)$

$T_{\psi} : \mathbb{R}^2 \to \mathbb{R}$

Trained with the binary cross-entropy loss to distinguish samples from $q_{\phi}$ (labeled 1) from samples from the prior $p$ (labeled 0), the optimal discriminator's pre-sigmoid output (logit) recovers the log density ratio $T_{\psi^*}(w) = \log q_{\phi}(w) - \log p(w)$.

In [106]:
discriminator = Sequential(name='discriminator')
discriminator.add(Dense(10, input_dim=LATENT_DIM, activation='relu'))
discriminator.add(Dense(20, activation='relu'))
discriminator.add(Dense(1, activation=None, name='logit'))
discriminator.add(Activation('sigmoid'))
discriminator.compile(optimizer='adam',
                      loss='binary_crossentropy',
                      metrics=['binary_accuracy'])
In [107]:
ratio_estimator = Model(
    inputs=discriminator.inputs, 
    outputs=discriminator.get_layer(name='logit').output)
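
`ratio_estimator` shares all of its weights with `discriminator` and simply exposes the pre-sigmoid `logit` layer. As a quick check (a sketch; `w_test` is an arbitrary batch of test points), the two models are related through the logistic sigmoid:

w_test = prior.rvs(size=4)
# discriminator probability == sigmoid of the logit exposed by ratio_estimator
np.testing.assert_allclose(
    discriminator.predict(w_test),
    expit(ratio_estimator.predict(w_test)),
    rtol=1e-5)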
In [108]:
SVG(model_to_dot(discriminator, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[108]:
[Model graph: dense_6_input: InputLayer (None, 2) → dense_6: Dense (None, 10) → dense_7: Dense (None, 20) → logit: Dense (None, 1) → activation_2: Activation (None, 1)]
In [109]:
w_grid_ratio = ratio_estimator.predict(w_grid.reshape(300*300, 2))
w_grid_ratio = w_grid_ratio.reshape(300, 300)

The initial density ratio estimate from the randomly initialized discriminator, prior to any training:

In [110]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, w_grid_ratio, cmap=plt.cm.magma)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
In [111]:
discriminator.evaluate(prior.rvs(size=5), np.zeros(5))
5/5 [==============================] - 0s
Out[111]:
[0.45842784643173218, 1.0]

Approximate Inference Model

In general, the approximate inference model is a function $z_{\phi}(x, \epsilon)$ of the data and an auxiliary noise variable $\epsilon$. Here we only consider

$z_{\phi}(\epsilon)$

$z_{\phi}: \mathbb{R}^3 \to \mathbb{R}^2$

a deterministic mapping of noise $\epsilon \sim N(0, I_3)$ to weight samples $w = z_{\phi}(\epsilon)$. The resulting distribution $q_{\phi}(w)$ is implicit: we can sample from it, but cannot evaluate its density.

In [112]:
inference = Sequential()
inference.add(Dense(10, input_dim=NOISE_DIM, activation='relu'))
inference.add(Dense(20, activation='relu'))
inference.add(Dense(LATENT_DIM, activation=None))
inference.summary()
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
=================================================================
dense_8 (Dense)              (None, 10)                40        
_________________________________________________________________
dense_9 (Dense)              (None, 20)                220       
_________________________________________________________________
dense_10 (Dense)             (None, 2)                 42        
=================================================================
Total params: 302
Trainable params: 302
Non-trainable params: 0
_________________________________________________________________

The variational parameters $\phi$ are the trainable weights of the approximate inference model:

In [113]:
phi = inference.trainable_weights
phi
Out[113]:
[<tf.Variable 'dense_8/kernel:0' shape=(3, 10) dtype=float32_ref>,
 <tf.Variable 'dense_8/bias:0' shape=(10,) dtype=float32_ref>,
 <tf.Variable 'dense_9/kernel:0' shape=(10, 20) dtype=float32_ref>,
 <tf.Variable 'dense_9/bias:0' shape=(20,) dtype=float32_ref>,
 <tf.Variable 'dense_10/kernel:0' shape=(20, 2) dtype=float32_ref>,
 <tf.Variable 'dense_10/bias:0' shape=(2,) dtype=float32_ref>]
In [114]:
SVG(model_to_dot(inference, show_shapes=True)
    .create(prog='dot', format='svg'))
Out[114]:
[Model graph: dense_8_input: InputLayer (None, 3) → dense_8: Dense (None, 10) → dense_9: Dense (None, 20) → dense_10: Dense (None, 2)]
In [115]:
eps = np.random.randn(BATCH_SIZE, NOISE_DIM)
In [116]:
w_posterior_samples = inference.predict(eps)
w_posterior_samples.shape
Out[116]:
(128, 2)
In [117]:
w_prior_samples = prior.rvs(size=BATCH_SIZE)
w_prior_samples.shape
Out[117]:
(128, 2)
In [120]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(w1, w2, 
            np.exp(log_prior+np.sum(llhs, axis=2)), 
            cmap=plt.cm.magma)

ax.scatter(*w_posterior_samples.T, alpha=.6)

ax.set_xlabel('$w_1$')
ax.set_ylabel('$w_2$')

ax.set_xlim(w_min, w_max)
ax.set_ylim(w_min, w_max)

plt.show()
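
The notebook stops short of training, but to make the objective concrete, here is a minimal sketch of the AVB-style alternating updates under the conventions above (posterior samples labeled 1, prior samples labeled 0). The helper `graph_log_likelihood` and the training loop are illustrative assumptions, not code from this post: the discriminator is fit to estimate the log density ratio $\log q_{\phi}(w) - \log p(w)$, while the inference network descends the negative ELBO surrogate $\mathbb{E}_{\epsilon}[T_{\psi}(z_{\phi}(\epsilon)) - \log p(y \mid X, z_{\phi}(\epsilon))]$.

def graph_log_likelihood(w):
    # symbolic counterpart of `log_likelihood` above: log p(y_n | x_n, w)
    signs = 2. * K.constant(y, dtype='float32') - 1.  # +1 if y=1, -1 if y=0
    logits = K.dot(w, K.constant(X.T, dtype='float32')) * signs
    return K.sum(K.log(K.sigmoid(logits)), axis=-1)

eps_input = inference.inputs[0]
w_sample = inference(eps_input)                    # w = z_phi(eps)
t = K.squeeze(ratio_estimator(w_sample), axis=-1)  # T_psi(w), shape (batch,)

# negative ELBO (up to an additive constant): E_q[T_psi(w) - log p(y | X, w)]
nelbo = K.mean(t - graph_log_likelihood(w_sample))

# update only the variational parameters phi, holding psi fixed
train_op = K.tf.train.AdamOptimizer(1e-3).minimize(nelbo, var_list=phi)
train_step = K.function([eps_input], [nelbo], updates=[train_op])

for _ in range(1000):
    # (1) discriminator update: posterior samples (1) vs. prior samples (0)
    w_q = inference.predict(np.random.randn(D_BATCH_SIZE, NOISE_DIM))
    w_p = prior.rvs(size=D_BATCH_SIZE)
    discriminator.train_on_batch(
        np.vstack([w_q, w_p]),
        np.hstack([np.ones(D_BATCH_SIZE), np.zeros(D_BATCH_SIZE)]))
    # (2) inference network update: descend the negative ELBO surrogate
    train_step([np.random.randn(G_BATCH_SIZE, NOISE_DIM)])

With these pieces in place, re-running the scatter plot above should show the samples $w = z_{\phi}(\epsilon)$ migrating toward the unnormalized posterior.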