import tensorflow.python as tf from tensorflow.python import keras as k from tensorflow.python.keras import layers as kl from tensorflow.python.keras import activations as ka import matplotlib.pyplot as plt import numpy as np from scipy.stats import norm from scipy.special import expit
# Enable on-demand GPU memory growth and register the resulting session
# as the one the Keras backend uses.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
k.backend.set_session(tf.Session(config=config))
# Load Fashion-MNIST, append a trailing channel axis to the images and
# divide the pixel values by 255.
(x_train, y_train), (x_test, y_test) = k.datasets.fashion_mnist.load_data()
x_train = x_train[..., np.newaxis] / 255.
x_test = x_test[..., np.newaxis] / 255.
# --- Data geometry ---
image_size = 28
input_shape = (image_size, image_size, 1)

# --- Network architecture ---
kernel_size = 3
filters = 16
latent_dim = 2

# --- Training schedule ---
batch_size = 100
epochs = 30

# Fix the graph-level random seed before any layers are created.
tf.set_random_seed(9102)
def encoder_fn(inputs, filters):
    """Build the convolutional encoder and return its two latent heads.

    Two stride-2 ``Conv2D`` layers (with ``2 * filters`` and ``4 * filters``
    channels) are followed by ``Flatten``, a 32-unit ReLU ``Dense`` layer and
    two parallel linear ``Dense(latent_dim)`` heads.

    Args:
        inputs: Input tensor fed to the first convolution.
        filters: Base channel count; it is doubled before each convolution.

    Returns:
        A pair of tensors ``(μ, σ)``, each produced by a ``Dense(latent_dim)``
        layer. Elsewhere in this file (``sampling`` and ``loss_fn``) the
        second one is used as a log-variance.
    """
    features = inputs
    # The channel count doubles ahead of each of the two down-sampling convs.
    for depth in (filters * 2, filters * 4):
        features = kl.Conv2D(
            filters=depth,
            kernel_size=kernel_size,
            activation='relu',
            strides=2,
            padding='same',
        )(features)
    flat = kl.Flatten()(features)
    hidden = kl.Dense(32, activation='relu')(flat)
    z_mean = kl.Dense(latent_dim)(hidden)
    z_log_var = kl.Dense(latent_dim)(hidden)
    return z_mean, z_log_var
def sampling(args):
    """Apply the reparameterization trick.

    Draws standard-normal noise ``ε`` with the same shape as ``μ`` and
    returns ``μ + exp(σ / 2) * ε``.

    Args:
        args: A pair ``(μ, σ)`` of tensors.

    Returns:
        The sampled tensor, shaped like ``μ``.
    """
    mean, log_var = args
    noise = tf.random_normal(shape=tf.shape(mean))
    scale = tf.exp(log_var / 2)
    return mean + scale * noise
def decoder_fn(z, filters):
    """Build the transposed-convolution decoder on top of a latent tensor.

    The latent input is projected by a ReLU ``Dense`` layer to ``7 * 7 * 32``
    units, reshaped to ``(7, 7, 32)``, passed through two stride-2
    ``Conv2DTranspose`` layers (``filters`` then ``filters // 2`` channels)
    and finally through a single-channel ``Conv2DTranspose`` without
    activation.

    Args:
        z: Latent input tensor.
        filters: Channel count of the first up-sampling layer; it is halved
            for the second one.

    Returns:
        The single-channel output tensor. No activation is applied, so the
        values are raw logits.
    """
    projected = kl.Dense(7 * 7 * 32, activation='relu')(z)
    features = kl.Reshape((7, 7, 32))(projected)
    # The channel count halves after each of the two up-sampling layers.
    for depth in (filters, filters // 2):
        features = kl.Conv2DTranspose(
            filters=depth,
            kernel_size=kernel_size,
            activation='relu',
            strides=2,
            padding='same',
        )(features)
    logits = kl.Conv2DTranspose(
        1, kernel_size, activation=None, padding='same')(features)
    return logits
def loss_fn(inputs, outputs, μ, σ):
    """Return the VAE loss: reconstruction cross-entropy plus KL term.

    The reconstruction term is the sigmoid cross-entropy between ``inputs``
    (labels) and ``outputs`` (logits), summed over axes 1, 2 and 3 of each
    sample. The KL term is ``-0.5 * sum(1 + σ - μ² - exp(σ))`` over the last
    axis. The two per-sample terms are added and averaged over the batch.

    Args:
        inputs: Target tensor used as the cross-entropy labels; it must have
            axes 1, 2 and 3 (e.g. batch, height, width, channels).
        outputs: Logits tensor with the same shape as ``inputs``.
        μ: Latent mean tensor.
        σ: Latent tensor used as a log-variance (``exp(σ)`` enters the KL
            term as the variance).

    Returns:
        A scalar tensor holding the batch-mean loss.
    """
    # Use the function's own arguments rather than the module-level
    # ``x_in`` / ``x_out`` tensors, so the loss reflects what the caller
    # actually passed in.
    xent_loss = tf.reduce_sum(
        tf.nn.sigmoid_cross_entropy_with_logits(labels=inputs, logits=outputs),
        axis=[1, 2, 3],
    )
    kl_loss = -0.5 * tf.reduce_sum(
        1 + σ - tf.square(μ) - tf.exp(σ), axis=-1)
    vae_loss = tf.reduce_mean(xent_loss + kl_loss)
    return vae_loss
# Encoder side of the graph: image input -> (μ, σ) -> sampled latent z.
x_in = k.Input(shape=input_shape)
μ, σ = encoder_fn(x_in, filters)
z = kl.Lambda(sampling, output_shape=(latent_dim,))([μ, σ])
# Stand-alone decoder model with its own latent input, so it can be reused
# for generation; it is then applied to the sampled latent z.
latent_inputs = k.Input(shape=(latent_dim,), dtype=tf.float32)
outputs = decoder_fn(latent_inputs, filters)
decoder = k.Model(inputs=latent_inputs, outputs=outputs)
x_out = decoder(z)
# Encoder model mapping an input image to its latent mean μ.
encoder = k.Model(inputs=x_in, outputs=μ)
# End-to-end VAE. The loss is attached with add_loss, so compile() gets no
# loss argument and fit() gets no targets.
vae = k.Model(inputs=x_in, outputs=x_out)
vae.add_loss(loss_fn(x_in, x_out, μ, σ))
vae.compile(optimizer=k.optimizers.Nadam(0.001))
vae.fit(
    x=x_train,
    batch_size=batch_size,
    epochs=epochs,
    shuffle=True,
    validation_data=(x_test, None),
)
# Scatter the test images' two latent-mean coordinates, coloured by label.
x_test_encoded = encoder.predict(x_test, batch_size=batch_size)
plt.figure(figsize=(6, 6))
plt.scatter(
    x=x_test_encoded[:, 0],
    y=x_test_encoded[:, 1],
    c=y_test,
)
plt.colorbar()
plt.show()
# Decode an n x n grid of latent points into one large tiled image.
n = 15
figure = np.zeros((image_size * n, image_size * n))

# Grid coordinates are standard-normal quantiles at evenly spaced
# probabilities between 0.05 and 0.95.
grid_x = norm.ppf(np.linspace(0.05, 0.95, n))
grid_y = norm.ppf(np.linspace(0.05, 0.95, n))

for i, yi in enumerate(grid_x):
    rows = slice(i * image_size, (i + 1) * image_size)
    for j, xi in enumerate(grid_y):
        cols = slice(j * image_size, (j + 1) * image_size)
        z_sample = np.array([[xi, yi]])
        # The decoder emits logits; expit (sigmoid) maps them into (0, 1).
        x_decoded = expit(decoder.predict(z_sample))
        digit = x_decoded[0].reshape(image_size, image_size)
        figure[rows, cols] = digit

plt.figure(figsize=(10, 10))
plt.imshow(figure, cmap='Greys_r')
plt.show()
|