Exploring the Google QuickDraw Dataset with SketchRNN (Part 3)

t-SNE Visualization of Sheep Sketches

This is the third part in a series of notes on my exploration of the recently released Google QuickDraw dataset¹, using the concurrently released SketchRNN model.

The QuickDraw dataset is curated from the millions of drawings contributed by over 15 million players of the "Quick, Draw!" A.I. Experiment, in which they were challenged to draw objects belonging to a particular class (such as "cat") in under 20 seconds.

SketchRNN is an impressive generative model that was trained to produce vector drawings using this dataset. It was of particular interest to me because it cleverly assembles many of the tools and techniques recently developed in machine learning, such as Variational Autoencoders, HyperLSTMs (a HyperNetwork for LSTM), autoregressive models, Layer Normalization, Recurrent Dropout, and the Adam optimizer.

Again, I've discarded the markdown cells and code blocks that were intended to explain or demonstrate something, retaining only the code needed to run the experiments in this notebook. Everything up to the section Exploring the Latent Space with Principal Component Analysis was copied directly from previous notebooks, and consists mostly of utility functions to facilitate visualization. Feel free to skip right ahead to that section, as that is where the really interesting analysis happens. Here are links to the first and second notes.

These notebooks were derived from the notebook included with the code release. I've made significant stylistic changes and some minor changes to ensure Python 3 forward compatibility².


  1. This is somewhat misleading, as we are mainly exploring the Aaron Koblin Sheep Market (aaron-sheep) dataset, a smaller, lightweight dataset provided with the sketch-rnn release, along with a notebook that demos various models pre-trained on this dataset. It was a natural starting point for experimenting with sketch-rnn. Since the dataset schema is the same as that of the QuickDraw dataset, the experiments performed here on this dataset carry over without loss of generality.

  2. Magenta currently supports only Python 2.

In [2]:
%matplotlib inline
%config InlineBackend.figure_format = 'svg'
%load_ext autoreload
%autoreload 2
In [3]:
import matplotlib.pyplot as plt
import matplotlib.patches as patches

import numpy as np
import tensorflow as tf

from matplotlib.animation import FuncAnimation
from matplotlib.path import Path
from matplotlib import rc

from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

from itertools import product
from six.moves import map, zip
In [5]:
from magenta.models.sketch_rnn.sketch_rnn_train import (
    load_env,
    load_checkpoint,
    reset_graph,
    download_pretrained_models,
    PRETRAINED_MODELS_URL)
from magenta.models.sketch_rnn.model import Model, sample
from magenta.models.sketch_rnn.utils import (lerp,
                                             slerp,
                                             get_bounds, 
                                             to_big_strokes,
                                             to_normal_strokes)
In [6]:
# For inline display of animations
# equivalent to rcParams['animation.html'] = 'html5'
rc('animation', html='html5')
In [5]:
# set numpy output to something sensible
np.set_printoptions(precision=8, 
                    edgeitems=6, 
                    linewidth=200, 
                    suppress=True)
In [6]:
tf.logging.info("TensorFlow Version: {}".format(tf.__version__))
INFO:tensorflow:TensorFlow Version: 1.1.0

Getting the Pre-Trained Models and Data

In [7]:
DATA_DIR = ('http://github.com/hardmaru/sketch-rnn-datasets/'
            'raw/master/aaron_sheep/')
MODELS_ROOT_DIR = '/tmp/sketch_rnn/models'
In [8]:
DATA_DIR
Out[8]:
'http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep/'
In [9]:
PRETRAINED_MODELS_URL
Out[9]:
'http://download.magenta.tensorflow.org/models/sketch_rnn.zip'
In [10]:
download_pretrained_models(
    models_root_dir=MODELS_ROOT_DIR,
    pretrained_models_url=PRETRAINED_MODELS_URL)
INFO:tensorflow:Downloading pretrained models from http://download.magenta.tensorflow.org/models/sketch_rnn.zip...
INFO:tensorflow:Download complete.
INFO:tensorflow:Unzipping /tmp/sketch_rnn/models/sketch_rnn.zip...
INFO:tensorflow:Unzipping complete.

We look at the layer-normalized model trained on the aaron_sheep dataset for now.

In [11]:
MODEL_DIR = MODELS_ROOT_DIR + '/aaron_sheep/layer_norm'
In [12]:
(train_set, 
 valid_set, 
 test_set, 
 hps_model, 
 eval_hps_model, 
 sample_hps_model) = load_env(DATA_DIR, MODEL_DIR)
INFO:tensorflow:Downloading http://github.com/hardmaru/sketch-rnn-datasets/raw/master/aaron_sheep/aaron_sheep.npz
INFO:tensorflow:Loaded 7400/300/300 from aaron_sheep.npz
INFO:tensorflow:Dataset combined: 8000 (7400/300/300), avg len 125
INFO:tensorflow:model_params.max_seq_len 250.
total images <= max_seq_len is 7400
total images <= max_seq_len is 300
total images <= max_seq_len is 300
INFO:tensorflow:normalizing_scale_factor 18.5198.
In [13]:
class SketchPath(Path):
    
    def __init__(self, data, factor=.2, *args, **kwargs):
        
        # convert the (dx, dy) offsets to absolute coordinates
        vertices = np.cumsum(data[:, :-1], axis=0) / factor
        # the pen state of point i-1 determines the drawing code
        # of point i, hence the shift by one
        codes = np.roll(self.to_code(data[:, -1].astype(int)), 
                        shift=1)
        codes[0] = Path.MOVETO

        super(SketchPath, self).__init__(vertices, 
                                         codes, 
                                         *args, 
                                         **kwargs)
        
    @staticmethod
    def to_code(cmd):
        # if cmd == 0 (pen down), the code is LINETO
        # if cmd == 1 (pen up), the code is MOVETO (which is LINETO - 1)
        return Path.LINETO - cmd
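
To make the input format concrete, here is a hypothetical toy example (not taken from the dataset): each row of the stroke-3 format is a (dx, dy, pen state) triple, where a pen state of 1 means the pen is lifted after reaching that point, so the next point begins a new stroke.

# a hypothetical two-stroke figure in stroke-3 format;
# columns are (dx, dy, pen_state)
toy_sketch = np.array([[ 10,   0, 0],   # line to the right
                       [  0,  10, 0],   # line up
                       [-10,   0, 1],   # line left, then lift pen
                       [ 30, -10, 0],   # pen jumps here (MOVETO)
                       [ 10,  10, 1]])  # final line, end of sketch
toy_path = SketchPath(toy_sketch)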
In [14]:
def draw(sketch_data, factor=.2, pad=(10, 10), ax=None):

    if ax is None:
        ax = plt.gca()

    x_pad, y_pad = pad
    
    x_pad //= 2
    y_pad //= 2
        
    x_min, x_max, y_min, y_max = get_bounds(data=sketch_data,
                                            factor=factor)

    ax.set_xlim(x_min-x_pad, x_max+x_pad)
    # invert the y-axis so sketches are drawn right-side up
    ax.set_ylim(y_max+y_pad, y_min-y_pad)

    sketch = SketchPath(sketch_data, factor=factor)

    patch = patches.PathPatch(sketch, facecolor='none')
    ax.add_patch(patch)
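
As a quick usage example (a minimal sketch, assuming the data loaded above), we can render the first sketch in the test set:

# render the first sketch in the test set
fig, ax = plt.subplots()
draw(test_set.strokes[0], ax=ax)
plt.show()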

Load pre-trained models

In [15]:
# construct the sketch-rnn model here:
reset_graph()
model = Model(hps_model)
eval_model = Model(eval_hps_model, reuse=True)
sample_model = Model(sample_hps_model, reuse=True)
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = True.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = False.
INFO:tensorflow:Model using gpu.
INFO:tensorflow:Input dropout mode = False.
INFO:tensorflow:Output dropout mode = False.
INFO:tensorflow:Recurrent dropout mode = False.
In [16]:
sess = tf.InteractiveSession()
sess.run(tf.global_variables_initializer())
In [17]:
# loads the weights from checkpoint into our model
load_checkpoint(sess=sess, checkpoint_path=MODEL_DIR)
INFO:tensorflow:Loading model /tmp/sketch_rnn/models/aaron_sheep/layer_norm/vector.
INFO:tensorflow:Restoring parameters from /tmp/sketch_rnn/models/aaron_sheep/layer_norm/vector
In [18]:
def encode(input_strokes):
    # convert from stroke-3 to the padded stroke-5 format
    # expected by the model
    strokes = to_big_strokes(input_strokes).tolist()
    # prepend the start-of-sequence token
    strokes.insert(0, [0, 0, 1, 0, 0])
    # the unpadded sequence length
    seq_len = [len(input_strokes)]
    # run the encoder to obtain the latent vector z
    z = sess.run(eval_model.batch_z,
                 feed_dict={
                    eval_model.input_data: [strokes], 
                    eval_model.sequence_lengths: seq_len})[0]
    return z
In [19]:
def decode(z_input=None, temperature=.1, factor=.2):
    z = None
    if z_input is not None:
        z = [z_input]
    # sample a sketch from the decoder, conditioned on z
    # (or unconditionally if z is None); temperature controls
    # the randomness of the sampling
    sample_strokes, _ = sample(
        sess, 
        sample_model, 
        seq_len=eval_model.hps.max_seq_len, 
        temperature=temperature, z=z)
    # convert the stroke-5 output back to stroke-3 format
    return to_normal_strokes(sample_strokes)
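
For example, we can round-trip a sketch through the model by encoding it to its latent vector and decoding a reconstruction from it (a minimal sketch; the temperature value here is an arbitrary choice):

# encode a test sketch into its latent vector z, then sample
# a reconstruction conditioned on z
z = encode(test_set.strokes[0])
reconstruction = decode(z, temperature=.3)

fig, ax = plt.subplots()
draw(reconstruction, ax=ax)
plt.show()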

Exploring the Latent Space with Principal Component Analysis

What do you call a baby eigensheep? A lamb, duh.

We encode all of the sketches in the test set into their learned 128-dimensional latent space representations.

In [20]:
Z = np.vstack(map(encode, test_set.strokes))
Z.shape
Out[20]:
(300, 128)

Then, we find the top two principal axes, which represent the directions of maximum variance of the data encoded in the latent space.

In [22]:
pca = PCA(n_components=2)
In [23]:
pca.fit(Z)
Out[23]:
PCA(copy=True, iterated_power='auto', n_components=2, random_state=None,
  svd_solver='auto', tol=0.0, whiten=False)

The two components each account for only about 2% of the variance.

In [24]:
pca.explained_variance_ratio_
Out[24]:
array([ 0.02140247,  0.02067117])
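
Together, the two components capture only about 4% of the total variance. For perspective, a perfectly isotropic 128-dimensional code would spread 1/128 ≈ 0.8% across each direction, so these two directions are only modestly dominant. A quick check (a minimal sketch using the variables defined above):

# combined variance captured by the first two components (~0.042)
print(pca.explained_variance_ratio_.sum())
# baseline for an isotropic 128-dimensional code (~0.0078)
print(1. / Z.shape[1])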

Let's project the data from the 128-dimensional latent space down to the 2-dimensional subspace spanned by the first two principal components.

In [25]:
Z_pca = pca.transform(Z)
Z_pca.shape
Out[25]:
(300, 2)
In [26]:
fig, ax = plt.subplots()

ax.scatter(*Z_pca.T)

ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')

plt.show()

We'd like to visualize the original sketches at their corresponding points on this plot, where each point is the 2-dimensional reduction of a sketch's latent code. However, the plot is too dense to fit sufficiently large sketches without overlapping them. Therefore, we restrict our attention to a smaller region encompassing roughly 80% of the data points, discarding those outside the 5th and 95th percentiles along each axis (which keeps 90% per axis, hence about 0.9² ≈ 81% jointly; we verify this below). The blue shaded rectangle highlights our region of interest.

In [106]:
((pc1_min, pc2_min), 
 (pc1_max, pc2_max)) = np.percentile(Z_pca, q=[5, 95], axis=0)
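
As a sanity check (a minimal sketch using the variables above), the fraction of points inside this rectangle should come out to roughly 80%, since each axis keeps 90% and 0.9² ≈ 0.81 if the two coordinates are roughly independent:

# fraction of test sketches whose projections fall inside
# the region of interest
inside = np.all((Z_pca >= [pc1_min, pc2_min]) &
                (Z_pca <= [pc1_max, pc2_max]), axis=1)
print(inside.mean())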
In [107]:
roi_rect = patches.Rectangle(xy=(pc1_min, pc2_min),
                             width=pc1_max-pc1_min,
                             height=pc2_max-pc2_min, alpha=.4)
In [108]:
fig, ax = plt.subplots()

ax.scatter(*Z_pca.T)
ax.add_patch(roi_rect)

ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')

plt.show()
In [109]:
fig, ax = plt.subplots(figsize=(10, 10))

ax.set_xlim(pc1_min, pc1_max)
ax.set_ylim(pc2_min, pc2_max)

for i, sketch in enumerate(test_set.strokes):
    sketch_path = SketchPath(sketch, factor=7e+1)
    # flip the y-axis so sketches are drawn right-side up
    sketch_path.vertices[:, 1] *= -1
    # center each sketch at its latent code's 2D projection
    sketch_path.vertices += Z_pca[i]
    patch = patches.PathPatch(sketch_path, facecolor='none')
    ax.add_patch(patch)

ax.set_xlabel('$pc_1$')
ax.set_ylabel('$pc_2$')
    
plt.savefig("../../files/sketchrnn/aaron_sheep_pca.svg", 
            format="svg", dpi=1200)