import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline
from tensorflow.examples.tutorials.mnist import input_data
# Load MNIST with one-hot labels, then preview one training digit.
mnist = input_data.read_data_sets("MNIST_data/", one_hot=True)
sample_digit = mnist.train.images[28].reshape(28, 28)
plt.imshow(sample_digit, cmap='Greys')
def generator(z, reuse=None):
    """Two-layer MLP generator.

    Maps latent noise ``z`` (shape [batch, 100] at the call sites in this
    file) to a 784-dim vector in [-1, 1] via a tanh output, matching the
    rescaling applied to the real images.
    """
    with tf.variable_scope('gen', reuse=reuse):
        alpha = 0.01  # leaky-ReLU negative slope

        def _lrelu(t):
            # Leaky ReLU written as max(alpha*t, t).
            return tf.maximum(alpha * t, t)

        h1 = _lrelu(tf.layers.dense(inputs=z, units=128))
        h2 = _lrelu(tf.layers.dense(inputs=h1, units=128))
        # tanh keeps outputs in [-1, 1].
        return tf.layers.dense(h2, units=784, activation=tf.nn.tanh)
def discriminator(X, reuse=None):
    """Two-layer MLP discriminator.

    Takes a batch of 784-dim images ``X`` and returns
    ``(output, logits)`` where ``output = sigmoid(logits)`` is the
    probability that each input is real. Variables live under scope
    'dis' so real/fake passes can share weights via ``reuse=True``.
    """
    with tf.variable_scope('dis', reuse=reuse):
        alpha = 0.01  # leaky-ReLU negative slope

        def _lrelu(t):
            # Leaky ReLU written as max(alpha*t, t).
            return tf.maximum(alpha * t, t)

        h1 = _lrelu(tf.layers.dense(inputs=X, units=128))
        h2 = _lrelu(tf.layers.dense(inputs=h1, units=128))
        logits = tf.layers.dense(h2, units=1)
        output = tf.sigmoid(logits)
        return output, logits
# Real MNIST images, flattened 28x28 (rescaled to [-1, 1] by the training loop).
real_images = tf.placeholder(tf.float32,shape=[None,784])
# Latent noise fed to the generator.
z = tf.placeholder(tf.float32,shape=[None,100])
# Generator output (creates variables under scope 'gen').
G = generator(z)
# Discriminator on real images (creates the 'dis' variables)...
D_output_real , D_logits_real = discriminator(real_images)
# ...and on generated images, reusing the same 'dis' variables.
D_output_fake, D_logits_fake = discriminator(G,reuse=True)
def loss_func(logits_in, labels_in):
    """Mean sigmoid cross-entropy between raw logits and target labels."""
    per_example = tf.nn.sigmoid_cross_entropy_with_logits(
        logits=logits_in, labels=labels_in)
    return tf.reduce_mean(per_example)
# Label smoothing: real targets are 0.9 rather than 1.0 to stabilise D.
D_real_loss = loss_func(D_logits_real, tf.ones_like(D_logits_real) * 0.9)
# FIX: fake labels must be shaped like the *fake* logits. The original used
# tf.zeros_like(D_logits_real), silently coupling the fake loss to the
# real-image placeholder's runtime batch size.
D_fake_loss = loss_func(D_logits_fake, tf.zeros_like(D_logits_fake))
D_loss = D_real_loss + D_fake_loss
# Generator wants the discriminator to label its samples as real (1).
G_loss = loss_func(D_logits_fake, tf.ones_like(D_logits_fake))
learning_rate = 0.001
# Split trainable variables by scope so each optimizer only touches its network.
tvars = tf.trainable_variables()
d_vars = [var for var in tvars if 'dis' in var.name]
g_vars = [var for var in tvars if 'gen' in var.name]
print([v.name for v in d_vars])
print([v.name for v in g_vars])
D_trainer = tf.train.AdamOptimizer(learning_rate).minimize(D_loss, var_list=d_vars)
G_trainer = tf.train.AdamOptimizer(learning_rate).minimize(G_loss, var_list=g_vars)
batch_size = 100
epochs = 500
saver = tf.train.Saver()
# One generated sample per epoch, kept for viewing after training.
samples = []
# FIX: in the original script order `init` was only defined at the bottom of
# the file, so sess.run(init) below raised NameError. Build it here, after
# the D/G trainers (and their Adam slot variables) exist.
init = tf.global_variables_initializer()
# FIX: build the sampling op once. The original called generator(z, reuse=True)
# inside the epoch loop, adding new ops to the graph every epoch.
sample_op = generator(z, reuse=True)
with tf.Session() as sess:
    sess.run(init)
    # An epoch is one full pass through the training data.
    for e in range(epochs):
        num_batches = mnist.train.num_examples // batch_size
        for i in range(num_batches):
            # Grab a batch of images.
            batch = mnist.train.next_batch(batch_size)
            # Rescale from [0, 1] to [-1, 1] to match G's tanh output.
            batch_images = batch[0].reshape((batch_size, 784))
            batch_images = batch_images * 2 - 1
            # Latent noise in [-1, 1] (tanh-compatible range).
            batch_z = np.random.uniform(-1, 1, size=(batch_size, 100))
            # Run both optimizers; outputs are not needed.
            sess.run(D_trainer, feed_dict={real_images: batch_images, z: batch_z})
            sess.run(G_trainer, feed_dict={z: batch_z})
        print("Currently on Epoch {} of {} total...".format(e + 1, epochs))
        # Sample from the generator during training for viewing afterwards.
        sample_z = np.random.uniform(-1, 1, size=(1, 100))
        gen_sample = sess.run(sample_op, feed_dict={z: sample_z})
        samples.append(gen_sample)
    saver.save(sess, './models/500_epoch_model.ckpt')
# Placeholder for the single flattened 28x28 image we want to inpaint.
img = tf.placeholder(tf.float32, shape=(1, 784))
# Binary mask: 1 = known pixels, 0 = the 8x8 hole to be reconstructed.
mask = np.ones(shape=[28,28], dtype=np.float32)
mask[10:18,10:18] = 0
mask = mask.reshape(1,784)
# Known-pixel views of the target image and of the generator output.
# (tf.convert_to_tensor around tf.reshape is redundant — reshape already
# returns a tensor — but harmless.)
img_to_correct = tf.multiply(tf.reshape(img, shape=(28,28)), tf.convert_to_tensor(tf.reshape(mask, shape=(28,28))))
img_gen_masked = tf.multiply(tf.reshape(G, shape=(28,28)), tf.convert_to_tensor(tf.reshape(mask, shape=(28,28))))
# L1 distance between generated and target images on the known pixels.
# NOTE(review): this is a length-28 vector (sum over axis 1 of a 28x28 tensor),
# not a scalar; TF1 minimize() implicitly sums the gradients, but a full
# reduce_sum would be clearer — confirm intended.
contextual_loss = tf.reduce_sum(tf.abs(img_gen_masked - img_to_correct), 1)
perceptual_loss = G_loss
complete_loss = contextual_loss + 0.5*perceptual_loss
# NOTE(review): no var_list is given, so this updates ALL trainable variables
# (both 'gen' and 'dis' weights). z is a placeholder and cannot be optimized
# directly — verify this matches the intended inpainting scheme.
complete_loss_trainer = tf.train.AdamOptimizer(learning_rate).minimize(complete_loss)
# Built after the trainer so the new Adam slot variables are covered.
init = tf.global_variables_initializer()
# FIX: two statements below were fused onto single lines (a notebook-to-script
# merge artifact) and did not parse; they are split into separate statements.
import numpy as np
from sklearn.preprocessing import normalize
# Quick demo: L2-normalize a random vector with sklearn.
x = np.random.rand(100) * 10
norm = normalize(x[:, np.newaxis], axis=0).ravel()
print(x)
print('\n')
print(norm)
from sklearn.preprocessing import normalize
# Restore only the generator's variables from the trained checkpoint.
saver = tf.train.Saver(var_list=g_vars)
v = 0  # NOTE(review): appears unused below — likely leftover scratch.
new_samples = []
with tf.Session() as sess:
    sess.run(init)
    saver.restore(sess,'./models/500_epoch_model.ckpt')
    # Initial latent guess for the image being inpainted.
    zhat = np.random.uniform(-1, 1, size=(1, 100)).astype("float32")
    for x in range(1000):
        run = [complete_loss, complete_loss_trainer, G]
        # Feed the current latent guess and the (flattened) target image.
        loss, g, G_img = sess.run(run, feed_dict = {z: zhat, img: mnist.train.images[28].reshape(1, 784)})
        # Project zhat back onto the unit L2 sphere between steps.
        # NOTE(review): the optimizer step updates network weights, not zhat;
        # zhat only changes through this renormalization — confirm intended.
        zhat = normalize(zhat[:], axis=1).ravel()
        zhat = zhat.reshape(1, 100)
        # zhat = np.clip(zhat, -1, 1)
with tf.Session() as sess:
    # Evaluate the masked target image (hole region zeroed out).
    img_to_corr = sess.run(img_to_correct, feed_dict = {img: mnist.train.images[28].reshape(1, 784)})
    plt.imshow(img_to_corr, cmap='Greys')
    # Last generator output from the inpainting loop.
    ans = G_img.reshape(28, 28)
    plt.imshow(ans, cmap='Greys')
    # Preview: paste the generated 8x8 patch into a COPY of the masked image.
    # FIX: the original pasted the patch into img_to_corr itself and then also
    # added (1 - mask) * ans on top, doubling the pixel values in the hole.
    patched = img_to_corr.copy()
    patched[10:18, 10:18] = ans[10:18, 10:18]
    plt.imshow(patched, cmap='Greys')
    # Final reconstruction: known pixels from the masked image, hole from G,
    # each contributing exactly once.
    missing = np.multiply((1 - mask.reshape(28, 28)), ans)
    reconstructed = img_to_corr + missing
    plt.imshow(reconstructed, cmap='Greys')