GAN: generating MNIST digits, then using the generator to fill in a missing image patch

In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

The Data

In [2]:
from tensorflow.examples.tutorials.mnist import input_data

# Load MNIST from (or into) the MNIST_data/ directory, with one-hot encoded labels.
mnist = input_data.read_data_sets(train_dir="MNIST_data/", one_hot=True)
Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
In [4]:
plt.imshow(np.reshape(mnist.train.images[28], (28, 28)), cmap='Greys')
Out[4]:
<matplotlib.image.AxesImage at 0x7f27eae162e8>

The Generator

In [5]:
def generator(z, reuse=None, alpha=0.01):
    """Generator network: maps latent vectors to flattened 28x28 images.

    Architecture: dense(128) -> leaky ReLU -> dense(128) -> leaky ReLU ->
    dense(784, tanh), all created inside the 'gen' variable scope.

    Args:
        z: batch of latent vectors (this notebook feeds a [None, 100] placeholder).
        reuse: passed to tf.variable_scope('gen'); pass True to build another
            copy of the network that shares the already-created weights.
        alpha: slope of the leaky ReLU for negative inputs. max(alpha*x, x)
            is a leaky ReLU only for 0 <= alpha <= 1. Defaults to the
            previously hard-coded 0.01, so existing calls are unchanged.

    Returns:
        Tensor of shape [batch, 784] with values in [-1, 1] (tanh output).
    """
    with tf.variable_scope('gen', reuse=reuse):
        # Layers are created in the same order as before, so the auto-generated
        # variable names (gen/dense, gen/dense_1, gen/dense_2) -- and therefore
        # saved checkpoints -- stay compatible.
        hidden1 = tf.layers.dense(inputs=z, units=128)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # leaky ReLU
        hidden2 = tf.layers.dense(inputs=hidden1, units=128)
        hidden2 = tf.maximum(alpha * hidden2, hidden2)  # leaky ReLU
        output = tf.layers.dense(hidden2, units=784, activation=tf.nn.tanh)
        return output
    

The Discriminator

In [6]:
def discriminator(X, reuse=None, alpha=0.01):
    """Discriminator network: scores flattened images as real (1) or fake (0).

    Architecture: dense(128) -> leaky ReLU -> dense(128) -> leaky ReLU ->
    dense(1), all created inside the 'dis' variable scope.

    Args:
        X: batch of flattened images (this notebook feeds [None, 784] tensors).
        reuse: passed to tf.variable_scope('dis'); pass True to score a second
            input with the same, already-created weights.
        alpha: slope of the leaky ReLU for negative inputs. max(alpha*x, x)
            is a leaky ReLU only for 0 <= alpha <= 1. Defaults to the
            previously hard-coded 0.01, so existing calls are unchanged.

    Returns:
        (output, logits): the sigmoid probability and the raw logit, each of
        shape [batch, 1]. The logits are what the cross-entropy losses consume.
    """
    with tf.variable_scope('dis', reuse=reuse):
        # Same layer creation order as before, so variable names
        # (dis/dense, dis/dense_1, dis/dense_2) are unchanged.
        hidden1 = tf.layers.dense(inputs=X, units=128)
        hidden1 = tf.maximum(alpha * hidden1, hidden1)  # leaky ReLU

        hidden2 = tf.layers.dense(inputs=hidden1, units=128)
        hidden2 = tf.maximum(alpha * hidden2, hidden2)  # leaky ReLU

        logits = tf.layers.dense(hidden2, units=1)
        output = tf.sigmoid(logits)

        return output, logits

Placeholders

In [7]:
# Inputs with an open batch dimension: flattened 28x28 real images and 100-dim latent noise.
real_images = tf.placeholder(dtype=tf.float32, shape=(None, 784))
z = tf.placeholder(dtype=tf.float32, shape=(None, 100))

Generator

In [8]:
# Build the generator once; this first call creates the 'gen/...' variables that later cells train and reuse.
G = generator(z)

Discriminator

In [9]:
# First discriminator call (reuse defaults to None): creates the 'dis/...' variables and scores real images.
D_output_real , D_logits_real = discriminator(real_images)
In [10]:
# Score generated images with the same discriminator weights (reuse=True shares the 'dis' scope).
D_output_fake, D_logits_fake = discriminator(G,reuse=True)

Losses

In [11]:
def loss_func(logits_in, labels_in):
    """Return the mean sigmoid cross-entropy between `logits_in` and target `labels_in`."""
    per_element = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels_in, logits=logits_in)
    return tf.reduce_mean(per_element)
In [12]:
# One-sided label smoothing: real images are labelled 0.9 rather than 1.0.
D_real_loss = loss_func(D_logits_real, 0.9 * tf.ones_like(D_logits_real))
In [13]:
# Fake images should be classified as 0. Build the labels from the fake logits themselves so
# their shape always matches, and so this loss does not depend on the real-image placeholder.
D_fake_loss = loss_func(D_logits_fake, tf.zeros_like(D_logits_fake))
In [14]:
# Total discriminator loss: penalties on real and on generated images.
D_loss = D_fake_loss + D_real_loss
In [15]:
# The generator is rewarded when the discriminator labels its fakes as real (1).
G_loss = loss_func(logits_in=D_logits_fake, labels_in=tf.ones_like(D_logits_fake))

Optimizers

In [16]:
learning_rate = 1e-3  # shared by every Adam optimizer built in this notebook
In [17]:
tvars = tf.trainable_variables()

# Split trainable variables by the variable_scope they were created in. Matching the scope
# prefix ('dis/', 'gen/') instead of a bare substring avoids picking up any other variable
# whose name merely contains 'dis' or 'gen'.
d_vars = [var for var in tvars if var.name.startswith('dis/')]
g_vars = [var for var in tvars if var.name.startswith('gen/')]

print([v.name for v in d_vars])
print([v.name for v in g_vars])
['dis/dense/kernel:0', 'dis/dense/bias:0', 'dis/dense_1/kernel:0', 'dis/dense_1/bias:0', 'dis/dense_2/kernel:0', 'dis/dense_2/bias:0']
['gen/dense/kernel:0', 'gen/dense/bias:0', 'gen/dense_1/kernel:0', 'gen/dense_1/bias:0', 'gen/dense_2/kernel:0', 'gen/dense_2/bias:0']
In [18]:
# Separate Adam optimizers: each minimizes its own network's loss and, via var_list,
# updates only that network's variables (the other network is held fixed for that step).
D_trainer = tf.train.AdamOptimizer(learning_rate).minimize(D_loss, var_list=d_vars)
G_trainer = tf.train.AdamOptimizer(learning_rate).minimize(G_loss, var_list=g_vars)

Training Session

In [19]:
# Training hyper-parameters.
batch_size, epochs = 100, 500
# Built after both optimizers, so their variables are included in what it saves.
saver = tf.train.Saver()
In [20]:
# Save one generated sample per epoch so training progress can be inspected afterwards.
samples = []

with tf.Session() as sess:
    # Initialise every variable created so far (both networks and both Adam optimizers).
    # The original cell ran `init`, which is only defined in a later cell, so it failed
    # on a fresh kernel.
    sess.run(tf.global_variables_initializer())

    # An epoch is one full pass over the training data; // is integer (floor) division.
    num_batches = mnist.train.num_examples // batch_size

    for e in range(epochs):
        for _ in range(num_batches):
            batch = mnist.train.next_batch(batch_size)

            # Flatten and rescale images by x*2 - 1 (maps [0, 1] to [-1, 1], assuming
            # pixels arrive in [0, 1]) to match the generator's tanh output range.
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images * 2 - 1

            # Latent noise for the generator, drawn uniformly from [-1, 1].
            batch_z = np.random.uniform(-1, 1, size=(batch_size, 100))

            # One discriminator step, then one generator step on the same noise.
            # The returned values are not needed.
            sess.run(D_trainer, feed_dict={real_images: batch_images, z: batch_z})
            sess.run(G_trainer, feed_dict={z: batch_z})

        print("Currently on Epoch {} of {} total...".format(e + 1, epochs))

        # Sample through the existing generator tensor G. Calling generator(z, reuse=True)
        # here would add a fresh copy of the network to the graph on every epoch.
        sample_z = np.random.uniform(-1, 1, size=(1, 100))
        gen_sample = sess.run(G, feed_dict={z: sample_z})
        samples.append(gen_sample)

        # Checkpoint after every epoch (same path each time, so the latest overwrites).
        saver.save(sess, './models/500_epoch_model.ckpt')
In [21]:
img = tf.placeholder(dtype=tf.float32, shape=[1, 784])  # one flattened target image to complete
In [22]:
# Binary mask: 1 for known pixels, 0 for the missing central 8x8 patch (rows/cols 10..17).
hole = np.zeros((8, 8), dtype=np.float32)
mask = np.pad(hole, 10, mode='constant', constant_values=1).reshape(1, 784)
In [23]:
# Target image with the missing patch zeroed (mask is 1 on known pixels, 0 in the hole).
img_to_correct = tf.reshape(img, (28, 28)) * tf.reshape(mask, (28, 28))
In [24]:
# Generated image masked the same way; reshaping to 28x28 assumes G is evaluated for a single z.
img_gen_masked = tf.reshape(G, (28, 28)) * tf.reshape(mask, (28, 28))
In [25]:
# L1 distance between generated and target images over the known pixels only (the mask
# zeroes the hole in both). Summing over all axes yields a scalar. The previous per-row
# sum (axis=1) produced a 28-vector, which broadcast the 0.5 * perceptual term 28 times
# once the optimizer implicitly summed the loss (effective weight 14 instead of 0.5).
contextual_loss = tf.reduce_sum(tf.abs(img_gen_masked - img_to_correct))
In [26]:
# Perceptual (realism) term for completion: reuse the generator's adversarial loss, which is
# low when the discriminator scores G's output as real.
perceptual_loss = G_loss
In [27]:
# Completion objective: match the known pixels, plus a 0.5-weighted realism term.
complete_loss = contextual_loss + perceptual_loss * 0.5
In [28]:
# Fit only the generator's weights to the visible pixels (z is a placeholder, not a variable).
# Without var_list the optimizer would also update the discriminator, letting it lower the
# perceptual term by accepting anything instead of pushing G towards realistic output.
complete_loss_trainer = tf.train.AdamOptimizer(learning_rate).minimize(complete_loss, var_list=g_vars)
In [29]:
# Initializer for every variable currently in the graph, including the completion optimizer's
# Adam variables created above; it only covers variables that exist when this line runs.
init = tf.global_variables_initializer()

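Aside: how `sklearn.preprocessing.normalize` rescales a vector to unit (L2) norm. This is a scratch example and is not run in this notebook:

```python
import numpy as np
from sklearn.preprocessing import normalize

x = np.random.rand(100)*10

norm = normalize(x[:,np.newaxis], axis=0).ravel()

print(x)
print('\n')
print(norm)
```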

In [32]:
from sklearn.preprocessing import normalize

# Restore both networks. The perceptual term (G_loss) is scored by the discriminator, so
# leaving D at freshly initialised random weights would make that term meaningless.
saver = tf.train.Saver(var_list=g_vars + d_vars)
with tf.Session() as sess:
    # Initialise everything first (the completion optimizer's Adam variables are not in the
    # checkpoint), then overwrite the trained network weights from disk.
    sess.run(init)
    saver.restore(sess, './models/500_epoch_model.ckpt')

    target = mnist.train.images[28].reshape(1, 784)
    zhat = np.random.uniform(-1, 1, size=(1, 100)).astype("float32")
    fetches = [complete_loss, complete_loss_trainer, G]
    for step in range(1000):
        loss, _, G_img = sess.run(fetches, feed_dict={z: zhat, img: target})
        # Rescale zhat to unit L2 norm (shape stays (1, 100)). This is idempotent after the
        # first step, so zhat is effectively fixed and only the trainer's variables change.
        zhat = normalize(zhat, axis=1)
INFO:tensorflow:Restoring parameters from ./models/500_epoch_model.ckpt
In [33]:
# Evaluate the masked target as a NumPy array; this path needs only `img` fed (no variables).
with tf.Session() as sess:
    img_to_corr = sess.run(
        fetches=img_to_correct,
        feed_dict={img: np.reshape(mnist.train.images[28], (1, 784))},
    )

Image to be corrected

In [34]:
plt.imshow(X=img_to_corr, cmap='Greys')
Out[34]:
<matplotlib.image.AxesImage at 0x7f27e6f4d2b0>
In [35]:
ans = np.reshape(G_img, (28, 28))  # generator output from the final completion step, as an image

Corresponding generated (fake) image

In [36]:
plt.imshow(X=ans, cmap='Greys')
Out[36]:
<matplotlib.image.AxesImage at 0x7f27e6f3a7b8>

Retrieving the missing part and substituting it

In [37]:
missing_part = ans[slice(10, 18), slice(10, 18)]  # generated pixels for the masked 8x8 hole
In [38]:
# Paste the generated patch into the hole; note this modifies img_to_corr in place.
img_to_corr[slice(10, 18), slice(10, 18)] = missing_part
In [39]:
plt.imshow(X=img_to_corr, cmap='Greys')
Out[39]:
<matplotlib.image.AxesImage at 0x7f27e6ea1f28>

Alternate method

In [40]:
# Keep generated pixels only inside the hole (inverted mask), zero elsewhere.
missing = (1 - mask.reshape(28, 28)) * ans
In [41]:
# Rebuild from the masked original instead of trusting img_to_corr as-is: the substitution
# cell above writes the generated patch into img_to_corr in place, so adding `missing` on top
# would count the patch twice. Multiplying by the mask zeroes the hole before filling it.
reconstructed = img_to_corr * mask.reshape(28, 28) + missing
In [42]:
plt.imshow(X=reconstructed, cmap='Greys')
Out[42]:
<matplotlib.image.AxesImage at 0x7f27e6e117f0>

Finished!