Variational Inference with Implicit Approximate Inference Models - @fhuszar's Explaining Away Example Pt. 1 (WIP)

In [1]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
In [70]:
import numpy as np
import keras.backend as K

import matplotlib.pyplot as plt
import seaborn as sns

from scipy.stats import logistic, multivariate_normal, norm
from scipy.special import expit

from keras.models import Model, Sequential
from keras.layers import Activation, Add, Dense, Dot, Input
from keras.optimizers import Adam
from keras.utils.vis_utils import model_to_dot

from mpl_toolkits.mplot3d import Axes3D
from matplotlib.animation import FuncAnimation

from IPython.display import HTML, SVG, display_html
from tqdm import tnrange, tqdm_notebook
In [3]:
# display animation inline
plt.rc('animation', html='html5')
plt.style.use('seaborn-notebook')
sns.set_context('notebook')
In [4]:
np.set_printoptions(precision=2,
                    edgeitems=3,
                    linewidth=80,
                    suppress=True)
In [5]:
K.tf.__version__
Out[5]:
'1.2.1'
In [6]:
LATENT_DIM = 2
NOISE_DIM = 3
BATCH_SIZE = 200
PRIOR_VARIANCE = 2.
LEARNING_RATE = 3e-3
PRETRAIN_EPOCHS = 60

Explaining Away Example (Synthetic Data)
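To summarize the generative model assembled in the cells below (my reading of the setup; the notation matches the code):

$$
z \sim \mathcal{N}(0, \sigma^2 I_2), \qquad
x \mid z \sim \operatorname{Exp}(\beta(z)), \qquad
\beta(z) = \beta_0 + \beta_1 \sum_{i=1}^{2} \max(0, z_i^3),
$$

with $\sigma^2 = 2$ (PRIOR_VARIANCE), $\beta_0 = 3$, $\beta_1 = 1$, and $\operatorname{Exp}(\beta)$ denoting the exponential distribution with scale (mean) $\beta$.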

In [7]:
z_min, z_max = -5, 5
In [8]:
z1, z2 = np.mgrid[z_min:z_max:300j, z_min:z_max:300j]
In [9]:
z_grid = np.dstack((z1, z2))
z_grid.shape
Out[9]:
(300, 300, 2)
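Aside: the complex step 300j in np.mgrid is NumPy's convention for requesting 300 evenly spaced points with both endpoints included, e.g.

np.mgrid[-5:5:5j]  # array([-5. , -2.5,  0. ,  2.5,  5. ])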
In [10]:
prior = multivariate_normal(mean=np.zeros(LATENT_DIM), 
                            cov=PRIOR_VARIANCE)
In [11]:
log_prior = prior.logpdf(z_grid)
log_prior.shape
Out[11]:
(300, 300)
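The log-density of the isotropic Gaussian prior has the closed form

$$
\log p(z) = -\frac{\lVert z \rVert^2}{2\sigma^2} - \frac{d}{2}\log(2\pi\sigma^2),
$$

which for $d = 2$ is exactly the expression verified against scipy's multivariate_normal in the next cell.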
In [13]:
np.allclose(log_prior,
            -.5*np.sum(z_grid**2, axis=2)/PRIOR_VARIANCE
            - np.log(2*np.pi*PRIOR_VARIANCE))
Out[13]:
True
In [15]:
fig, ax = plt.subplots(figsize=(5, 5))

ax.contourf(z1, z2, log_prior, cmap='magma')

ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')

ax.set_xlim(z_min, z_max)
ax.set_ylim(z_min, z_max)

plt.show()
In [16]:
x = np.array([0, 5, 8, 12, 50])
In [37]:
def log_likelihood(z, x, beta_0=3., beta_1=1.):
    # exponential log-density with latent-dependent scale
    # beta(z) = beta_0 + beta_1 * sum_i max(0, z_i^3)
    beta = beta_0 + np.sum(beta_1*np.maximum(0, z**3), axis=-1)
    return -np.log(beta) - x/beta
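This is the log-density of an exponential distribution whose scale depends on the latent variables:

$$
\log p(x \mid z) = -\log \beta(z) - \frac{x}{\beta(z)}, \qquad
\beta(z) = \beta_0 + \beta_1 \sum_i \max(0, z_i^3).
$$

The rectified cubic means a large observation $x$ can be accounted for by either $z_1$ or $z_2$ being large, which is what gives the posterior its explaining-away structure.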
In [44]:
llhs = log_likelihood(z_grid, x.reshape(-1, 1, 1))
llhs.shape
Out[44]:
(5, 300, 300)
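Since the likelihood is just an exponential density with input-dependent scale, it can be cross-checked against scipy.stats.expon (a quick verification sketch, not part of the original notebook):

from scipy.stats import expon

beta = 3. + np.sum(np.maximum(0, z_grid**3), axis=-1)
# expon.logpdf(x, scale=b) = -x/b - log(b); should evaluate to True
np.allclose(llhs, expon(scale=beta).logpdf(x.reshape(-1, 1, 1)))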
In [59]:
fig, axes = plt.subplots(ncols=len(x), nrows=1, figsize=(20, 4))
fig.tight_layout()

for i, ax in enumerate(axes):
    
    ax.contourf(z1, z2, llhs[i], cmap='magma')

    ax.set_xlim(z_min, z_max)
    ax.set_ylim(z_min, z_max)
    
    ax.set_title(r'$p(x = {{{0}}} \mid z)$'.format(x[i]))
    ax.set_xlabel('$z_1$')    
    
    if not i:
        ax.set_ylabel('$z_2$')

plt.show()
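From here, the unnormalized log-posterior is just the log-prior plus the sum of the per-observation log-likelihoods. A minimal sketch using the arrays above (my extrapolation of where the example heads next, not part of the original cells):

# unnormalized log-posterior on the grid:
# log p(z | x) = log p(z) + sum_n log p(x_n | z) + const.
log_posterior = log_prior + llhs.sum(axis=0)

fig, ax = plt.subplots(figsize=(5, 5))
ax.contourf(z1, z2, log_posterior, cmap='magma')
ax.set_xlabel('$z_1$')
ax.set_ylabel('$z_2$')
plt.show()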