Exploring the Google QuickDraw Dataset with SketchRNN (Part 3)
This is the third part in a series of notes on my exploration of the recently released Google QuickDraw dataset [1], using the concurrently released SketchRNN model.
The QuickDraw dataset is curated from millions of drawings contributed by over 15 million people around the world who participated in the "Quick, Draw!" A.I. Experiment, in which they were challenged to draw objects belonging to a particular class (such as "cat") in under 20 seconds.
SketchRNN is an impressive generative model that was trained to produce vector drawings using this dataset. It was of particular interest to me because it cleverly assembles many of the latest tools and techniques recently developed in machine learning, such as Variational Autoencoders, HyperLSTMs (a HyperNetwork for LSTM), Autoregressive models, Layer Normalization, Recurrent Dropout, the Adam optimizer, among others.
Again, I've discarded the markdown cells and code blocks that were intended to explain or demonstrate something, retaining only the code needed to run the experiments in this notebook. Everything up to the section Exploring the Latent Space with Principal Component Analysis was copied directly from previous notebooks and consists mostly of utility functions to facilitate visualization. Feel free to skip right ahead to that section, as that is where the really interesting analysis happens. Here are links to the first and second notes.
These notebooks were derived from the notebook included with the code release. I've made significant stylistic changes and some minor changes to ensure Python 3 forward compatibility [2].
1. This is somewhat misleading, as we are mainly exploring the Aaron Koblin Sheep Market (aaron-sheep) dataset, a smaller, lightweight dataset provided with the sketch-rnn release, along with a notebook that demos various models already pre-trained on this dataset. It was a natural starting point for experimenting with sketch-rnn. Since the dataset schema is the same as that of the QuickDraw dataset, the experiments performed here are done without loss of generality.
2. Magenta only supports Python 2 currently.
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
import matplotlib.pyplot as plt
import matplotlib.patches as patches
import numpy as np
import tensorflow as tf
from matplotlib.animation import FuncAnimation
from matplotlib.path import Path
from matplotlib import rc
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE
from itertools import product
from six.moves import map, zip
from magenta.models.sketch_rnn.sketch_rnn_train import (
    load_env,
    load_checkpoint,
    reset_graph,
    download_pretrained_models,
    PRETRAINED_MODELS_URL,
)
from magenta.models.sketch_rnn.model import Model, sample
from magenta.models.sketch_rnn.utils import (
    lerp,
    slerp,
    get_bounds,
    to_big_strokes,
    to_normal_strokes,
)
# For inline display of animation
# equivalent to rcParams['animation.html'] = 'html5'
rc('animation', html='html5')
# set numpy output to something sensible
np.set_printoptions(precision=8,
edgeitems=6,
linewidth=200,
suppress=True)
tf.logging.info("TensorFlow Version: {}".format(tf.__version__))
Getting the Pre-Trained Models and Data
DATA_DIR = ('http://github.com/hardmaru/sketch-rnn-datasets/'
'raw/master/aaron_sheep/')
MODELS_ROOT_DIR = '/tmp/sketch_rnn/models'
DATA_DIR
PRETRAINED_MODELS_URL
download_pretrained_models(
models_root_dir=MODELS_ROOT_DIR,
pretrained_models_url=PRETRAINED_MODELS_URL)
We look at the layer-normalized model trained on the aaron_sheep dataset for now.
MODEL_DIR = MODELS_ROOT_DIR + '/aaron_sheep/layer_norm'
(train_set,
valid_set,
test_set,
hps_model,
eval_hps_model,
sample_hps_model) = load_env(DATA_DIR, MODEL_DIR)
class SketchPath(Path):

    def __init__(self, data, factor=.2, *args, **kwargs):
        # accumulate the (dx, dy) offsets in the first two columns
        # into absolute vertex coordinates
        vertices = np.cumsum(data[:, :-1], axis=0) / factor
        # shift the codes by one so that the pen state of the
        # *previous* point determines how the current one is drawn
        codes = np.roll(self.to_code(data[:, -1].astype(int)),
                        shift=1)
        codes[0] = Path.MOVETO
        super(SketchPath, self).__init__(vertices, codes,
                                         *args, **kwargs)

    @staticmethod
    def to_code(cmd):
        # if cmd == 0, the code is LINETO
        # if cmd == 1, the code is MOVETO (which is LINETO - 1)
        return Path.LINETO - cmd
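As a quick sanity check of this mapping (an illustrative snippet of my own, not part of the original notebook), a pen-state column of [1, 0, 0, 1] should map to MOVETO at the lifted positions and LINETO elsewhere:
# illustrative check: pen-up (1) maps to MOVETO, pen-down (0) to LINETO
cmd = np.array([1, 0, 0, 1])
assert np.array_equal(SketchPath.to_code(cmd),
                      [Path.MOVETO, Path.LINETO, Path.LINETO, Path.MOVETO])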
def draw(sketch_data, factor=.2, pad=(10, 10), ax=None):
    if ax is None:
        ax = plt.gca()
    x_pad, y_pad = pad
    x_pad //= 2
    y_pad //= 2
    x_min, x_max, y_min, y_max = get_bounds(data=sketch_data,
                                            factor=factor)
    ax.set_xlim(x_min - x_pad, x_max + x_pad)
    # invert the y-axis so sketches render right-side up
    ax.set_ylim(y_max + y_pad, y_min - y_pad)
    sketch = SketchPath(sketch_data)
    patch = patches.PathPatch(sketch, facecolor='none')
    ax.add_patch(patch)
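For example, here is a minimal usage sketch of the helper above, rendering the first drawing in the test set (assuming the datasets loaded earlier):
fig, ax = plt.subplots(figsize=(4, 4))
draw(test_set.strokes[0], ax=ax)
ax.set_axis_off()
plt.show()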
Load pre-trained models
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
# loads the weights from checkpoint into our model
load_checkpoint(sess=sess, checkpoint_path=MODEL_DIR)
def encode(input_strokes):
    # convert from stroke-3 to the padded stroke-5 format the model expects
    strokes = to_big_strokes(input_strokes).tolist()
    # prepend the start-of-sequence token
    strokes.insert(0, [0, 0, 1, 0, 0])
    seq_len = [len(input_strokes)]
    # run the encoder to obtain the latent code z
    z = sess.run(eval_model.batch_z,
                 feed_dict={
                     eval_model.input_data: [strokes],
                     eval_model.sequence_lengths: seq_len})[0]
    return z
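As a quick illustrative check, encoding a single test sketch should give back a vector whose length matches the model's 128-dimensional latent space:
z = encode(test_set.strokes[0])
z.shape  # expect (128,)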
def decode(z_input=None, temperature=.1):
    z = None
    if z_input is not None:
        z = [z_input]
    # sample a stroke-5 sequence from the decoder, conditioned on z
    # when one is given (unconditional otherwise)
    sample_strokes, _ = sample(
        sess,
        sample_model,
        seq_len=eval_model.hps.max_seq_len,
        temperature=temperature, z=z)
    # convert back to the more compact stroke-3 format
    return to_normal_strokes(sample_strokes)
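Putting the two helpers together, here is a minimal round-trip sketch of my own: encode a test drawing, decode it back at a low temperature, and draw the reconstruction alongside the original.
original = test_set.strokes[0]
reconstruction = decode(encode(original), temperature=.1)
fig, (ax1, ax2) = plt.subplots(ncols=2, figsize=(8, 4))
draw(original, ax=ax1)
draw(reconstruction, ax=ax2)
plt.show()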
Exploring the Latent Space with Principal Component Analysis
What do you call a baby eigensheep? A lamb, duh.
We encode all of the sketches in the test set into their learned 128-dimensional latent space representations.
Z = np.vstack(map(encode, test_set.strokes))
Z.shape
Then, we find the top two principal axes that represent the direction of maximum variance in the data encoded in the latent space.
pca = PCA(n_components=2)
pca.fit(Z)
The two components each account for about 2% of the variance:
pca.explained_variance_ratio_
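To put this figure in context, here is a supplementary check of my own (not part of the original analysis) that fits a PCA with all components and counts how many are needed to explain half, and then most, of the variance:
pca_full = PCA().fit(Z)
cumulative = np.cumsum(pca_full.explained_variance_ratio_)
# smallest number of components explaining at least 50% and 90%
# of the total variance, respectively
np.searchsorted(cumulative, [.5, .9]) + 1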
Let's project the data from the 128-dimensional latent space onto the 2-dimensional subspace spanned by the first two principal components.
Z_pca = pca.transform(Z)
Z_pca.shape
fig, ax = plt.subplots()
ax.scatter(*Z_pca.T)
ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
plt.show()
We'd like to visualize the original sketches at their corresponding points on this plot, since each point is the 2-dimensional reduction of a sketch's latent code. However, the plot is too dense to fit sufficiently large sketches without overlapping them, so we restrict our attention to a smaller region of interest, discarding points that fall outside the 5th and 95th percentiles along either axis. Since each axis independently retains its central 90% of points, the resulting rectangle contains roughly 80% of the data. The blue shaded rectangle highlights our region of interest.
((pc1_min, pc2_min),
(pc1_max, pc2_max)) = np.percentile(Z_pca, q=[5, 95], axis=0)
roi_rect = patches.Rectangle(xy=(pc1_min, pc2_min),
width=pc1_max-pc1_min,
height=pc2_max-pc2_min, alpha=.4)
fig, ax = plt.subplots()
ax.scatter(*Z_pca.T)
ax.add_patch(roi_rect)
ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
plt.show()
fig, ax = plt.subplots(figsize=(10, 10))
ax.set_xlim(pc1_min, pc1_max)
ax.set_ylim(pc2_min, pc2_max)
for i, sketch in enumerate(test_set.strokes):
    sketch_path = SketchPath(sketch, factor=7e+1)
    # flip each sketch vertically and center it on the 2D
    # projection of its latent code
    sketch_path.vertices[:, 1] *= -1
    sketch_path.vertices += Z_pca[i]
    patch = patches.PathPatch(sketch_path, facecolor='none')
    ax.add_patch(patch)
ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
plt.savefig("../../files/sketchrnn/aaron_sheep_pca.svg",
            format="svg", dpi=1200)