import numpy as np
import tensorflow.python as tf
import tensorflow.python.keras as tfk
import tensorflow.python.keras.layers as tfkl
import tensorflow_probability.python as tfp
import tensorflow_probability.python.layers as tfpl
import tensorflow_probability.python.distributions as tfd
from toolz import compose, pipe
from scipy.stats import norm
import matplotlib.pyplot as plt

# Let TensorFlow grow its GPU memory use on demand instead of reserving it
# all up front.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True

# The script runs eagerly, so the config has to be handed to the eager
# context itself; a Session-only config is not what eager ops consult.
# This must happen before any TensorFlow op is executed.
tf.enable_eager_execution(config=config)

# Keep the Keras backend session registered with the same options so any
# code path that still goes through a graph session sees identical settings.
tfk.backend.set_session(tf.Session(config=config))
# Load the MNIST train and test splits through the Keras dataset helper.
# Each split is unpacked as an (images, labels) pair.
(x_train, y_train), (x_test, y_test) = tfk.datasets.mnist.load_data()
def _preprocess(x, y):
    """Turn one raw (image, label) example into the structure fed to the model.

    The image is reshaped to 28x28x1, cast to float32 and divided by 255;
    the label is one-hot encoded with depth 10.

    Returns:
        A pair ``((image, one_hot_label), ())``. Both tensors sit in the
        first element; the second element is an empty tuple (presumably the
        "targets" slot Keras expects, left empty because the losses are
        attached to the model itself -- confirm against the fit call).
    """
    image = tf.cast(tf.reshape(x, (28, 28, 1)), tf.float32) / 255.
    label = tf.one_hot(y, 10)
    return (image, label), ()
def _batched_pipeline(dataset):
    """Apply per-example preprocessing, then batch by 256 and prefetch."""
    preprocessed = dataset.map(_preprocess)
    batched = preprocessed.batch(256)
    return batched.prefetch(tf.data.experimental.AUTOTUNE)


# Training examples are shuffled (buffer of 10000) before preprocessing;
# evaluation examples keep their original order.
train_dataset = _batched_pipeline(
    tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(10000))
eval_dataset = _batched_pipeline(
    tf.data.Dataset.from_tensor_slices((x_test, y_test)))
# Model hyperparameters.
input_shape = [28, 28, 1]  # image shape produced by _preprocess
encoded_size = 2           # latent dimensionality
base_depth = 16            # base number of convolution filters
num_class = 10             # matches the one-hot depth used in _preprocess

# One-hot class label input. The shape must be a tuple: `(num_class)` is
# just the int 10, which Keras versions that call tuple(shape) reject.
y_inputs = tfk.Input(shape=(num_class,), name='y_inputs')

# Small network mapping a one-hot label to the mean of that class's prior.
# Neither Dense layer has an activation, so the mapping is linear.
cluster_mean = tfk.Sequential([
    tfkl.InputLayer(input_shape=(num_class,)),
    tfkl.Dense(encoded_size * 2),
    tfkl.Dense(encoded_size),
])
mean = cluster_mean(y_inputs)

# Class-conditional prior: unit-scale Normal centred on the learned class
# mean, with the last batch dimension reinterpreted as the event dimension.
prior = tfd.Independent(tfd.Normal(loc=mean, scale=1),
                        reinterpreted_batch_ndims=1)
# Image input of the VAE.
x_inputs = tfk.Input(shape=input_shape, name='x_inputs')

# Convolutional encoder, assembled layer by layer. Two of the 'same'-padded
# convolutions use stride 2; the final 7x7 'valid' convolution is followed
# by a flatten and a dense layer sized for a full-covariance Gaussian.
encoder = tfk.Sequential()
encoder.add(tfkl.Conv2D(base_depth, 5, strides=1, padding='same',
                        activation=tf.nn.leaky_relu,
                        input_shape=input_shape))
encoder.add(tfkl.Conv2D(base_depth, 5, strides=2, padding='same',
                        activation=tf.nn.leaky_relu))
encoder.add(tfkl.Conv2D(2 * base_depth, 5, strides=1, padding='same',
                        activation=tf.nn.leaky_relu))
encoder.add(tfkl.Conv2D(2 * base_depth, 5, strides=2, padding='same',
                        activation=tf.nn.leaky_relu))
encoder.add(tfkl.Conv2D(4 * encoded_size, 7, strides=1, padding='valid',
                        activation=tf.nn.leaky_relu))
encoder.add(tfkl.Flatten())
encoder.add(tfkl.Dense(tfpl.MultivariateNormalTriL.params_size(encoded_size),
                       activation=None))
# The posterior layer carries a KL regulariser against the class-conditional
# prior, which adds the KL term to the model's losses.
encoder.add(tfpl.MultivariateNormalTriL(
    encoded_size,
    activity_regularizer=tfpl.KLDivergenceRegularizer(prior)))

encoder_outputs = encoder(x_inputs)
# Stand-alone latent input placeholder. The shape must be a tuple:
# `(encoded_size)` is just an int. It is not referenced by the decoder
# below in the visible code; kept because other code may use the name.
decoder_input = tfk.Input(shape=(encoded_size,), name='decoder_input')

# Transposed-convolution decoder: reshapes a latent vector to a 1x1 feature
# map, expands it with a 7x7 'valid' transposed convolution, upsamples twice
# with stride 2, and ends in a single-channel sigmoid convolution so every
# output pixel lies in (0, 1).
decoder = tfk.Sequential([
    tfkl.Reshape((1, 1, encoded_size), input_shape=[encoded_size]),
    tfkl.Conv2DTranspose(2 * base_depth, 7, strides=1, padding='valid',
                         activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=1, padding='same',
                         activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(2 * base_depth, 5, strides=2, padding='same',
                         activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1, padding='same',
                         activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=2, padding='same',
                         activation=tf.nn.leaky_relu),
    tfkl.Conv2DTranspose(base_depth, 5, strides=1, padding='same',
                         activation=tf.nn.leaky_relu),
    tfkl.Conv2D(filters=1, kernel_size=5, strides=1, padding='same',
                activation=tfk.activations.sigmoid),
])
# Decode the encoder's output back into image space.
decoder_outputs = decoder(encoder_outputs)

# Reconstruction term: element-wise binary cross-entropy between the input
# image and its reconstruction, summed over every element.
# NOTE(review): the sum has no axis, so it also runs over the batch and its
# scale grows with batch size -- confirm the intended balance against the
# KL regulariser attached to the encoder.
_pixelwise_bce = tfk.backend.binary_crossentropy(x_inputs, decoder_outputs)
restruct_loss = tfk.backend.sum(_pixelwise_bce)

# The model takes the image and its one-hot label and exposes both the
# reconstruction and the class-conditional prior mean as outputs.
vae = tfk.Model(inputs=[x_inputs, y_inputs], outputs=[decoder_outputs, mean])
vae.add_loss(restruct_loss)

# No `loss=` argument: training is driven only by the losses attached to
# the model (the term added above plus any layer regularisers).
vae.compile(optimizer='adam')
vae.summary()

vae.fit(train_dataset,
        validation_data=eval_dataset,
        epochs=15,
        verbose=0)
# Prior mean of every class: feed each one-hot label (the rows of the
# identity matrix) through the label-to-mean network.
class_mean = cluster_mean(np.eye(num_class, dtype=np.float32))

n = 15            # tiles per side of the figure
digit_size = 28   # side length of one decoded image, in pixels
figure = np.zeros((digit_size * n, digit_size * n))

# Class whose latent neighbourhood is visualised. This name is deliberately
# never rebound below, so it keeps meaning "class index" after plotting.
digit = 8

# Latent coordinates: standard-normal quantiles between the 5th and 95th
# percentile, shifted so the grid is centred on the chosen class's mean.
center = np.asarray(class_mean[digit])
grid_x = norm.ppf(np.linspace(0.05, 0.95, n)) + center[1]
grid_y = norm.ppf(np.linspace(0.05, 0.95, n)) + center[0]

# Despite the names, grid_y spans latent dimension 0 and varies along the
# figure's columns, while grid_x spans latent dimension 1 and varies along
# its rows. meshgrid + C-order ravel gives row-major (row, column) order.
latent_0, latent_1 = np.meshgrid(grid_y, grid_x)
samples = np.stack([latent_0.ravel(), latent_1.ravel()],
                   axis=1).astype(np.float32)

# Decode the whole grid in one call instead of n * n single-sample calls.
decoded = decoder.predict(samples)

# (n*n, H, W, 1) -> (row, col, H, W) -> (row, H, col, W) -> one big image,
# so tile (i, j) lands at figure[i*H:(i+1)*H, j*W:(j+1)*W].
tiles = decoded.reshape(n, n, digit_size, digit_size)
figure[:, :] = tiles.transpose(0, 2, 1, 3).reshape(digit_size * n,
                                                   digit_size * n)

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
|