diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb index 5621d6a358e..78fcd397087 100644 --- a/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb +++ b/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb @@ -1,324 +1,405 @@ { + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "name": "dcgan.ipynb", + "version": "0.3.2", + "provenance": [], + "collapsed_sections": [] + }, + "kernelspec": { + "name": "python2", + "display_name": "Python 2" + }, + "accelerator": "GPU" + }, "cells": [ { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "0TD5ZrvEMbhZ" }, + "cell_type": "markdown", "source": [ - "##### Copyright 2018 The TensorFlow Authors.\n", + "**Copyright 2018 The TensorFlow Authors**.\n", "\n", "Licensed under the Apache License, Version 2.0 (the \"License\").\n", "\n", - "# DCGAN: An example with tf.keras and eager\n", + "# Generating Handwritten Digits with DCGAN\n", "\n", - "\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\n", - " \u003cimg src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" /\u003eRun in Google Colab\u003c/a\u003e \n", - "\u003c/td\u003e\u003ctd\u003e\n", - "\u003ca target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"\u003e\u003cimg width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" /\u003eView source on GitHub\u003c/a\u003e\u003c/td\u003e\u003c/table\u003e" + "<table class=\"tfo-notebook-buttons\" align=\"left\"><td>\n", + "<a target=\"_blank\" href=\"https://colab.research.google.com/github/tensorflow/tensorflow/blob/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\">\n", + " <img src=\"https://www.tensorflow.org/images/colab_logo_32px.png\" />Run in Google Colab</a> \n", + "</td><td>\n", + "<a target=\"_blank\" href=\"https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/eager/python/examples/generative_examples/dcgan.ipynb\"><img width=32px src=\"https://www.tensorflow.org/images/GitHub-Mark-32px.png\" />View source on GitHub</a></td></table>" ] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "ITZuApL56Mny" }, + "cell_type": "markdown", "source": [ - "This notebook demonstrates how to generate images of handwritten digits using [tf.keras](https://www.tensorflow.org/programmers_guide/keras) and [eager execution](https://www.tensorflow.org/programmers_guide/eager). To do so, we use Deep Convolutional Generative Adverserial Networks ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)).\n", + "This tutorial demonstrates how to generate images of handwritten digits using a Deep Convolutional Generative Adversarial Network ([DCGAN](https://arxiv.org/pdf/1511.06434.pdf)). The code is written in [tf.keras](https://www.tensorflow.org/programmers_guide/keras) with [eager execution](https://www.tensorflow.org/programmers_guide/eager) enabled. 
" + ] + }, + { + "metadata": { + "colab_type": "toc", + "id": "x2McrO9bMyLN" + }, + "cell_type": "markdown", + "source": [ + ">[Generating Handwritten Digits with DCGAN](#scrollTo=0TD5ZrvEMbhZ)\n", "\n", - "This model takes about ~30 seconds per epoch (using tf.contrib.eager.defun to create graph functions) to train on a single Tesla K80 on Colab, as of July 2018.\n", + ">>[What are GANs?](#scrollTo=2MbKJY38Puy9)\n", "\n", - "Below is the output generated after training the generator and discriminator models for 150 epochs.\n", + ">>>[Import TensorFlow and enable eager execution](#scrollTo=e1_Y75QXJS6h)\n", + "\n", + ">>>[Load the dataset](#scrollTo=iYn4MdZnKCey)\n", + "\n", + ">>>[Use tf.data to create batches and shuffle the dataset](#scrollTo=PIGN6ouoQxt3)\n", + "\n", + ">>[Create the models](#scrollTo=THY-sZMiQ4UV)\n", + "\n", + ">>>[The Generator Model](#scrollTo=-tEyxE-GMC48)\n", + "\n", + ">>>[The Discriminator model](#scrollTo=D0IKnaCtg6WE)\n", + "\n", + ">>[Define the loss functions and the optimizer](#scrollTo=0FMYgY_mPfTi)\n", + "\n", + ">>>[Generator loss](#scrollTo=Jd-3GCUEiKtv)\n", + "\n", + ">>>[Discriminator loss](#scrollTo=PKY_iPSPNWoj)\n", + "\n", + ">>[Set up GANs for Training](#scrollTo=Rw1fkAczTQYh)\n", + "\n", + ">>[Train the GANs](#scrollTo=dZrd4CdjR-Fp)\n", + "\n", + ">>[Generated images](#scrollTo=P4M_vIbUi7c0)\n", + "\n", + ">>[Learn more about GANs](#scrollTo=k6qC-SbjK0yW)\n", + "\n" + ] + }, + { + "metadata": { + "colab_type": "text", + "id": "2MbKJY38Puy9" + }, + "cell_type": "markdown", + "source": [ + "## What are GANs?\n", + "GANs, or [Generative Adversarial Networks](https://arxiv.org/abs/1406.2661), are a framework for estimating generative models. Two models are trained simultaneously by an adversarial process: a Generator, which is responsible for generating data (say, images), and a Discriminator, which is responsible for estimating the probability that an image was drawn from the training data (the image is real), or was produced by the Generator (the image is fake). During training, the Generator becomes progressively better at generating images, until the Discriminator is no longer able to distinguish real images from fake. \n", + "\n", + "\n", + "\n", + "We will demonstrate this process end-to-end on MNIST. Below is an animation that shows a series of images produced by the Generator as it was trained for 50 epochs. Overtime, the generated images become increasingly difficult to distinguish from the training set.\n", + "\n", + "To learn more about GANs, we recommend MIT's [Intro to Deep Learning](http://introtodeeplearning.com/) course, which includes a lecture on Deep Generative Models ([video](https://youtu.be/JVb54xhEw6Y) | [slides](http://introtodeeplearning.com/materials/2018_6S191_Lecture4.pdf)). 
Now, let's head to the code!\n", "\n", "" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "u_2z-B3piVsw" + "id": "u_2z-B3piVsw", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "# to generate gifs\n", + "# Install imageio in order to generate an animated gif showing the image generation process\n", "!pip install imageio" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "e1_Y75QXJS6h" }, + "cell_type": "markdown", "source": [ - "## Import TensorFlow and enable eager execution" + "### Import TensorFlow and enable eager execution" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "YfIk2es3hJEd" + "id": "YfIk2es3hJEd", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ - "from __future__ import absolute_import, division, print_function\n", - "\n", - "# Import TensorFlow \u003e= 1.10 and enable eager execution\n", "import tensorflow as tf\n", "tf.enable_eager_execution()\n", "\n", - "import os\n", - "import time\n", - "import numpy as np\n", "import glob\n", - "import matplotlib.pyplot as plt\n", - "import PIL\n", "import imageio\n", + "import matplotlib.pyplot as plt\n", + "import numpy as np\n", + "import os\n", + "import PIL\n", + "import time\n", + "\n", "from IPython import display" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "iYn4MdZnKCey" }, + "cell_type": "markdown", "source": [ - "## Load the dataset\n", + "### Load the dataset\n", "\n", - "We are going to use the MNIST dataset to train the generator and the discriminator. The generator will then generate handwritten digits." + "We are going to use the MNIST dataset to train the generator and the discriminator. The generator will generate handwritten digits resembling the MNIST data."
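Before moving on to the models, it can be worth sanity-checking the shapes and value range of the data that the next few code cells prepare. A minimal standalone sketch, assuming the same `tf.keras.datasets.mnist` loader and normalization used below:

```python
import tensorflow as tf

# Load MNIST and apply the same reshape/normalization as the notebook.
(train_images, _), (_, _) = tf.keras.datasets.mnist.load_data()
train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')
train_images = (train_images - 127.5) / 127.5  # scale [0, 255] -> [-1, 1]

print(train_images.shape)                      # (60000, 28, 28, 1)
print(train_images.min(), train_images.max())  # -1.0 1.0
```

The [-1, 1] range matters because the generator's final tanh layer also produces values in [-1, 1], so real and generated images live in the same range when the discriminator sees them.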
] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "a4fYMGxGhrna" + "id": "a4fYMGxGhrna", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "(train_images, train_labels), (_, _) = tf.keras.datasets.mnist.load_data()" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "NFC2ghIdiZYE" + "id": "NFC2ghIdiZYE", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "train_images = train_images.reshape(train_images.shape[0], 28, 28, 1).astype('float32')\n", - "# We are normalizing the images to the range of [-1, 1]\n", - "train_images = (train_images - 127.5) / 127.5" - ] + "train_images = (train_images - 127.5) / 127.5 # Normalize the images to [-1, 1]" + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "S4PIDhoDLbsZ" + "id": "S4PIDhoDLbsZ", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "BUFFER_SIZE = 60000\n", "BATCH_SIZE = 256" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "PIGN6ouoQxt3" }, + "cell_type": "markdown", "source": [ - "## Use tf.data to create batches and shuffle the dataset" + "### Use tf.data to create batches and shuffle the dataset" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "-yKCCQOoJ7cn" + "id": "-yKCCQOoJ7cn", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "train_dataset = tf.data.Dataset.from_tensor_slices(train_images).shuffle(BUFFER_SIZE).batch(BATCH_SIZE)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "THY-sZMiQ4UV" }, - "source": [ - "## Write the generator and discriminator models\n", - "\n", - "* **Generator** \n", - " * It is responsible for **creating convincing images that are good enough to fool the discriminator**.\n", - " * It consists of Conv2DTranspose (Upsampling) layers. We start with a fully connected layer and upsample the image 2 times so as to reach the desired image size (mnist image size) which is (28, 28, 1). \n", - " * We use **leaky relu** activation except for the **last layer** which uses **tanh** activation.\n", - " \n", - "* **Discriminator**\n", - " * **The discriminator is responsible for classifying the fake images from the real images.**\n", - " * In other words, the discriminator is given generated images (from the generator) and the real MNIST images. The job of the discriminator is to classify these images into fake (generated) and real (MNIST images).\n", - " * **Basically the generator should be good enough to fool the discriminator that the generated images are real**." 
- ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "VGLbvBEmjK0a" - }, - "outputs": [], - "source": [ - "class Generator(tf.keras.Model):\n", - " def __init__(self):\n", - " super(Generator, self).__init__()\n", - " self.fc1 = tf.keras.layers.Dense(7*7*64, use_bias=False)\n", - " self.batchnorm1 = tf.keras.layers.BatchNormalization()\n", - " \n", - " self.conv1 = tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(1, 1), padding='same', use_bias=False)\n", - " self.batchnorm2 = tf.keras.layers.BatchNormalization()\n", - " \n", - " self.conv2 = tf.keras.layers.Conv2DTranspose(32, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n", - " self.batchnorm3 = tf.keras.layers.BatchNormalization()\n", - " \n", - " self.conv3 = tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False)\n", - "\n", - " def call(self, x, training=True):\n", - " x = self.fc1(x)\n", - " x = self.batchnorm1(x, training=training)\n", - " x = tf.nn.relu(x)\n", - "\n", - " x = tf.reshape(x, shape=(-1, 7, 7, 64))\n", - "\n", - " x = self.conv1(x)\n", - " x = self.batchnorm2(x, training=training)\n", - " x = tf.nn.relu(x)\n", - "\n", - " x = self.conv2(x)\n", - " x = self.batchnorm3(x, training=training)\n", - " x = tf.nn.relu(x)\n", - "\n", - " x = tf.nn.tanh(self.conv3(x)) \n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "bkOfJxk5j5Hi" - }, - "outputs": [], - "source": [ - "class Discriminator(tf.keras.Model):\n", - " def __init__(self):\n", - " super(Discriminator, self).__init__()\n", - " self.conv1 = tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same')\n", - " self.conv2 = tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same')\n", - " self.dropout = tf.keras.layers.Dropout(0.3)\n", - " self.flatten = tf.keras.layers.Flatten()\n", - " self.fc1 = tf.keras.layers.Dense(1)\n", - "\n", - " def call(self, x, training=True):\n", - " x = tf.nn.leaky_relu(self.conv1(x))\n", - " x = self.dropout(x, training=training)\n", - " x = tf.nn.leaky_relu(self.conv2(x))\n", - " x = self.dropout(x, training=training)\n", - " x = self.flatten(x)\n", - " x = self.fc1(x)\n", - " return x" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "gDkA05NE6QMs" - }, - "outputs": [], - "source": [ - "generator = Generator()\n", - "discriminator = Discriminator()" - ] - }, - { - "cell_type": "code", - "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "k1HpMSLImuRi" - }, - "outputs": [], - "source": [ - "# Defun gives 10 secs/epoch performance boost\n", - "generator.call = tf.contrib.eager.defun(generator.call)\n", - "discriminator.call = tf.contrib.eager.defun(discriminator.call)" - ] - }, - { "cell_type": "markdown", + "source": [ + "## Create the models\n", + "\n", + "We will use tf.keras [Sequential API](https://www.tensorflow.org/guide/keras#sequential_model) to define the generator and discriminator models." + ] + }, + { + "metadata": { + "colab_type": "text", + "id": "-tEyxE-GMC48" + }, + "cell_type": "markdown", + "source": [ + "### The Generator Model\n", + "\n", + "The generator is responsible for creating convincing images that are good enough to fool the discriminator. 
The network architecture for the generator consists of [Conv2DTranspose](https://www.tensorflow.org/api_docs/python/tf/keras/layers/Conv2DTranspose) (Upsampling) layers. We start with a fully connected layer and upsample the image two times in order to reach the desired image size of 28x28x1. We increase the width and height, and reduce the depth as we move through the layers in the network. We use [Leaky ReLU](https://www.tensorflow.org/api_docs/python/tf/keras/layers/LeakyReLU) activation for each layer except for the last one where we use a tanh activation." + ] + }, + { + "metadata": { + "id": "6bpTcDqoLWjY", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "def make_generator_model():\n", + " model = tf.keras.Sequential()\n", + " model.add(tf.keras.layers.Dense(7*7*256, use_bias=False, input_shape=(100,)))\n", + " model.add(tf.keras.layers.BatchNormalization())\n", + " model.add(tf.keras.layers.LeakyReLU())\n", + " \n", + " model.add(tf.keras.layers.Reshape((7, 7, 256)))\n", + " assert model.output_shape == (None, 7, 7, 256) # Note: None is the batch size\n", + " \n", + " model.add(tf.keras.layers.Conv2DTranspose(128, (5, 5), strides=(1, 1), padding='same', use_bias=False))\n", + " assert model.output_shape == (None, 7, 7, 128) \n", + " model.add(tf.keras.layers.BatchNormalization())\n", + " model.add(tf.keras.layers.LeakyReLU())\n", + "\n", + " model.add(tf.keras.layers.Conv2DTranspose(64, (5, 5), strides=(2, 2), padding='same', use_bias=False))\n", + " assert model.output_shape == (None, 14, 14, 64) \n", + " model.add(tf.keras.layers.BatchNormalization())\n", + " model.add(tf.keras.layers.LeakyReLU())\n", + "\n", + " model.add(tf.keras.layers.Conv2DTranspose(1, (5, 5), strides=(2, 2), padding='same', use_bias=False, activation='tanh'))\n", + " assert model.output_shape == (None, 28, 28, 1)\n", + " \n", + " return model" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "D0IKnaCtg6WE" + }, + "cell_type": "markdown", + "source": [ + "### The Discriminator model\n", + "\n", + "The discriminator is responsible for distinguishing fake images from real images. It's similar to a regular CNN-based image classifier." 
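As a quick, optional check of the generator defined above (the discriminator, defined in the next cell, can be spot-checked the same way), you can push a random noise vector through an untrained instance and confirm the output shape. A sketch, assuming eager execution is enabled as in the imports at the top:

```python
# Untrained generator sanity check: a 100-dim noise vector should map to
# a single 28x28x1 image, with values in [-1, 1] thanks to the final tanh.
generator_check = make_generator_model()
noise = tf.random_normal([1, 100])
generated_image = generator_check(noise, training=False)

print(generated_image.shape)  # (1, 28, 28, 1)
plt.imshow(generated_image[0, :, :, 0].numpy(), cmap='gray')
plt.show()
```

At this point the output is just structured noise; training is what turns this mapping into plausible digits.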
+ ] + }, + { + "metadata": { + "id": "dw2tPLmk2pEP", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "def make_discriminator_model():\n", + " model = tf.keras.Sequential()\n", + " model.add(tf.keras.layers.Conv2D(64, (5, 5), strides=(2, 2), padding='same'))\n", + " model.add(tf.keras.layers.LeakyReLU())\n", + " model.add(tf.keras.layers.Dropout(0.3))\n", + " \n", + " model.add(tf.keras.layers.Conv2D(128, (5, 5), strides=(2, 2), padding='same'))\n", + " model.add(tf.keras.layers.LeakyReLU())\n", + " model.add(tf.keras.layers.Dropout(0.3))\n", + " \n", + " model.add(tf.keras.layers.Flatten())\n", + " model.add(tf.keras.layers.Dense(1))\n", + " \n", + " return model" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "code", + "id": "gDkA05NE6QMs", + "colab": {} + }, + "cell_type": "code", + "source": [ + "generator = make_generator_model()\n", + "discriminator = make_discriminator_model()" + ], + "execution_count": 0, + "outputs": [] + }, + { "metadata": { "colab_type": "text", "id": "0FMYgY_mPfTi" }, + "cell_type": "markdown", "source": [ "## Define the loss functions and the optimizer\n", "\n", - "* **Discriminator loss**\n", - " * The discriminator loss function takes 2 inputs; **real images, generated images**\n", - " * real_loss is a sigmoid cross entropy loss of the **real images** and an **array of ones (since these are the real images)**\n", - " * generated_loss is a sigmoid cross entropy loss of the **generated images** and an **array of zeros (since these are the fake images)**\n", - " * Then the total_loss is the sum of real_loss and the generated_loss\n", - " \n", - "* **Generator loss**\n", - " * It is a sigmoid cross entropy loss of the generated images and an **array of ones**\n", - " \n", - "\n", - "* The discriminator and the generator optimizers are different since we will train them separately." + "Let's define the loss functions and the optimizers for the generator and the discriminator.\n" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, - "colab_type": "code", - "id": "wkMNfBWlT-PV" + "colab_type": "text", + "id": "Jd-3GCUEiKtv" }, - "outputs": [], + "cell_type": "markdown", + "source": [ + "### Generator loss\n", + "The generator loss is a sigmoid cross entropy loss of the generated images and an array of ones, since the generator is trying to generate fake images that resemble the real images." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "90BIcCKcDMxz", + "colab": {} + }, + "cell_type": "code", + "source": [ + "def generator_loss(generated_output):\n", + " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "PKY_iPSPNWoj" + }, + "cell_type": "markdown", + "source": [ + "### Discriminator loss\n", + "\n", + "The discriminator loss function takes two inputs: real images, and generated images. Here is how to calculate the discriminator loss:\n", + "1. Calculate real_loss which is a sigmoid cross entropy loss of the real images and an array of ones (since these are the real images).\n", + "2. Calculate generated_loss which is a sigmoid cross entropy loss of the generated images and an array of zeros (since these are the fake images).\n", + "3. Calculate the total_loss as the sum of real_loss and generated_loss." 
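Before the code for these losses, a small numeric illustration may help (the logit values here are made up for illustration, not part of the notebook): a discriminator that is confidently correct incurs a near-zero total loss, while the very same logits give the generator a large loss. Assumes eager execution is enabled as above:

```python
# Hypothetical logits: the discriminator scores a real image high
# and a generated image low (i.e., it is confidently correct).
real_output = tf.constant([[5.0]])
generated_output = tf.constant([[-5.0]])

real_loss = tf.losses.sigmoid_cross_entropy(
    multi_class_labels=tf.ones_like(real_output), logits=real_output)
generated_loss = tf.losses.sigmoid_cross_entropy(
    multi_class_labels=tf.zeros_like(generated_output), logits=generated_output)
print((real_loss + generated_loss).numpy())  # ~0.013: discriminator is winning

gen_loss = tf.losses.sigmoid_cross_entropy(
    multi_class_labels=tf.ones_like(generated_output), logits=generated_output)
print(gen_loss.numpy())  # ~5.007: generator is losing on the same logits
```

Each term is log(1 + exp(-|logit|)) when the prediction matches the label, and grows roughly linearly in |logit| when it does not, which is what pushes the two networks against each other.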
+ ] + }, + { + "metadata": { + "colab_type": "code", + "id": "wkMNfBWlT-PV", + "colab": {} + }, + "cell_type": "code", "source": [ "def discriminator_loss(real_output, generated_output):\n", - " # [1,1,...,1] with real output since it is true and we want\n", - " # our generated examples to look like it\n", + " # [1,1,...,1] with real output since it is true and we want our generated examples to look like it\n", " real_loss = tf.losses.sigmoid_cross_entropy(multi_class_labels=tf.ones_like(real_output), logits=real_output)\n", "\n", " # [0,0,...,0] with generated images since they are fake\n", @@ -327,55 +408,51 @@ " total_loss = real_loss + generated_loss\n", "\n", " return total_loss" - ] - }, - { - "cell_type": "code", + ], "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "90BIcCKcDMxz" - }, - "outputs": [], - "source": [ - "def generator_loss(generated_output):\n", - " return tf.losses.sigmoid_cross_entropy(tf.ones_like(generated_output), generated_output)" - ] + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, - "colab_type": "code", - "id": "iWCn_PVdEJZ7" + "colab_type": "text", + "id": "MgIc7i0th_Iu" }, - "outputs": [], - "source": [ - "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)\n", - "generator_optimizer = tf.train.AdamOptimizer(1e-4)" - ] - }, - { "cell_type": "markdown", + "source": [ + "The discriminator and the generator optimizers are different since we will train two networks separately." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "iWCn_PVdEJZ7", + "colab": {} + }, + "cell_type": "code", + "source": [ + "generator_optimizer = tf.train.AdamOptimizer(1e-4)\n", + "discriminator_optimizer = tf.train.AdamOptimizer(1e-4)" + ], + "execution_count": 0, + "outputs": [] + }, + { "metadata": { "colab_type": "text", "id": "mWtinsGDPJlV" }, + "cell_type": "markdown", "source": [ - "## Checkpoints (Object-based saving)" + "**Checkpoints (Object-based saving)**" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "CA1w-7s2POEy" + "id": "CA1w-7s2POEy", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "checkpoint_dir = './training_checkpoints'\n", "checkpoint_prefix = os.path.join(checkpoint_dir, \"ckpt\")\n", @@ -383,60 +460,184 @@ " discriminator_optimizer=discriminator_optimizer,\n", " generator=generator,\n", " discriminator=discriminator)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "Rw1fkAczTQYh" }, + "cell_type": "markdown", "source": [ - "## Training\n", - "\n", - "* We start by iterating over the dataset\n", - "* The generator is given **noise as an input** which when passed through the generator model will output a image looking like a handwritten digit\n", - "* The discriminator is given the **real MNIST images as well as the generated images (from the generator)**.\n", - "* Next, we calculate the generator and the discriminator loss.\n", - "* Then, we calculate the gradients of loss with respect to both the generator and the discriminator variables (inputs) and apply those to the optimizer.\n", - "\n", - "## Generate Images\n", - "\n", - "* After training, its time to generate some images!\n", - "* We start by creating noise array as an input to the generator\n", - "* The generator will then convert the noise into handwritten images.\n", - "* Last step is to plot the predictions and **voila!**" + "## Set up GANs for 
Training\n", "\n" ] }, { "metadata": { "colab_type": "text", "id": "5QC5BABamh_c" }, "cell_type": "markdown", "source": [ "Now it's time to put together the generator and discriminator to set up the Generative Adversarial Networks, as you see in the diagram at the beginning of the tutorial." ] }, { "metadata": { "colab_type": "text", "id": "Ff6oN6PZX27n" }, "cell_type": "markdown", "source": [ "**Define training parameters**" ] }, { "metadata": { "colab_type": "code", "id": "NS2GWywBbAWo", "colab": {} }, "cell_type": "code", "source": [ "EPOCHS = 50\n", "noise_dim = 100\n", "num_examples_to_generate = 16\n", "\n", - "# keeping the random vector constant for generation (prediction) so\n", - "# it will be easier to see the improvement of the gan.\n", + "# We'll re-use this random vector to seed the generator so\n", + "# it will be easier to see the improvement over time.\n", "random_vector_for_generation = tf.random_normal([num_examples_to_generate,\n", " noise_dim])" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "jylSonrqSWfi" }, "cell_type": "markdown", "source": [ "**Define training method**\n", "\n", "We start by iterating over the dataset. The generator is given a random vector as input, which it processes to output an image that looks like a handwritten digit. The discriminator is then shown the real MNIST images as well as the generated images.\n", "\n", "Next, we calculate the generator and the discriminator loss. Then, we calculate the gradients of the losses with respect to the generator and the discriminator variables, and apply them with the respective optimizers." ] }, { "metadata": { "id": "3t5ibNo05jCB", "colab_type": "code", "colab": {} }, "cell_type": "code", "source": [ "def train_step(images):\n", " # generating noise from a normal distribution\n", " noise = tf.random_normal([BATCH_SIZE, noise_dim])\n", " \n", " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", " generated_images = generator(noise, training=True)\n", " \n", " real_output = discriminator(images, training=True)\n", " generated_output = discriminator(generated_images, training=True)\n", " \n", " gen_loss = generator_loss(generated_output)\n", " disc_loss = discriminator_loss(real_output, generated_output)\n", " \n", " gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)\n", " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)\n", " \n", " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))\n", " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))" ], "execution_count": 0, "outputs": [] }, { "metadata": { "colab_type": "text", "id": "6TSZgwc2BUQ-" }, "cell_type": "markdown", "source": [ "\n", "This model takes about 30 seconds per epoch to train on a single Tesla K80 on Colab, as of October 2018. \n", "\n", "Eager execution can be slower than executing the equivalent graph, as it can't benefit from whole-program optimizations on the graph, and also incurs overheads of interpreting Python code. 
By using [tf.contrib.eager.defun](https://www.tensorflow.org/api_docs/python/tf/contrib/eager/defun) to create graph functions, we get a ~20 secs/epoch performance boost (from ~50 secs/epoch down to ~30 secs/epoch). This way we get the best of both eager execution (easier for debugging) and graph mode (better performance)." + ] + }, + { + "metadata": { + "id": "Iwya07_j5p2A", + "colab_type": "code", + "colab": {} + }, + "cell_type": "code", + "source": [ + "train_step = tf.contrib.eager.defun(train_step)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "code", + "id": "2M7LmLtGEMQJ", + "colab": {} + }, + "cell_type": "code", + "source": [ + "def train(dataset, epochs): \n", + " for epoch in range(epochs):\n", + " start = time.time()\n", + " \n", + " for images in dataset:\n", + " train_step(images)\n", + "\n", + " display.clear_output(wait=True)\n", + " generate_and_save_images(generator,\n", + " epoch + 1,\n", + " random_vector_for_generation)\n", + " \n", + " # saving (checkpoint) the model every 15 epochs\n", + " if (epoch + 1) % 15 == 0:\n", + " checkpoint.save(file_prefix = checkpoint_prefix)\n", + " \n", + " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n", + " time.time()-start))\n", + " # generating after the final epoch\n", + " display.clear_output(wait=True)\n", + " generate_and_save_images(generator,\n", + " epochs,\n", + " random_vector_for_generation)" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "2aFF7Hk3XdeW" + }, + "cell_type": "markdown", + "source": [ + "**Generate and save images**\n", + "\n" + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "RmdVsmvhPxyy", + "colab": {} + }, + "cell_type": "code", "source": [ "def generate_and_save_images(model, epoch, test_input):\n", " # make sure the training parameter is set to False because we\n", @@ -452,164 +653,130 @@ " \n", " plt.savefig('image_at_epoch_{:04d}.png'.format(epoch))\n", " plt.show()" - ] - }, - { - "cell_type": "code", + ], "execution_count": 0, - "metadata": { - "colab": {}, - "colab_type": "code", - "id": "2M7LmLtGEMQJ" - }, - "outputs": [], - "source": [ - "def train(dataset, epochs, noise_dim): \n", - " for epoch in range(epochs):\n", - " start = time.time()\n", - " \n", - " for images in dataset:\n", - " # generating noise from a uniform distribution\n", - " noise = tf.random_normal([BATCH_SIZE, noise_dim])\n", - " \n", - " with tf.GradientTape() as gen_tape, tf.GradientTape() as disc_tape:\n", - " generated_images = generator(noise, training=True)\n", - " \n", - " real_output = discriminator(images, training=True)\n", - " generated_output = discriminator(generated_images, training=True)\n", - " \n", - " gen_loss = generator_loss(generated_output)\n", - " disc_loss = discriminator_loss(real_output, generated_output)\n", - " \n", - " gradients_of_generator = gen_tape.gradient(gen_loss, generator.variables)\n", - " gradients_of_discriminator = disc_tape.gradient(disc_loss, discriminator.variables)\n", - " \n", - " generator_optimizer.apply_gradients(zip(gradients_of_generator, generator.variables))\n", - " discriminator_optimizer.apply_gradients(zip(gradients_of_discriminator, discriminator.variables))\n", - "\n", - " \n", - " if epoch % 1 == 0:\n", - " display.clear_output(wait=True)\n", - " generate_and_save_images(generator,\n", - " epoch + 1,\n", - " random_vector_for_generation)\n", - " \n", - " # saving (checkpoint) the model every 15 epochs\n", - " if (epoch + 1) % 15 == 0:\n", - " 
checkpoint.save(file_prefix = checkpoint_prefix)\n", " \n", " print ('Time taken for epoch {} is {} sec'.format(epoch + 1,\n", " time.time()-start))\n", " # generating after the final epoch\n", " display.clear_output(wait=True)\n", " generate_and_save_images(generator,\n", " epochs,\n", " random_vector_for_generation)" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, - "colab_type": "code", - "id": "Ly3UN0SLLY2l" + "colab_type": "text", + "id": "dZrd4CdjR-Fp" }, - "outputs": [], - "source": [ - "train(train_dataset, EPOCHS, noise_dim)" - ] - }, - { "cell_type": "markdown", + "source": [ + "## Train the GANs\n", + "We will call the train() method defined above to train the generator and discriminator simultaneously. Note: training GANs can be tricky. It's important that the generator and discriminator do not overpower each other (e.g., that they train at a similar rate).\n", + "\n", + "At the beginning of training, the generated images look like random noise. As training progresses, you can see the generated digits look increasingly real. After 50 epochs, they look very much like the MNIST digits." + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "Ly3UN0SLLY2l", + "colab": {} + }, + "cell_type": "code", + "source": [ + "%%time\n", + "train(train_dataset, EPOCHS)" + ], + "execution_count": 0, + "outputs": [] + }, + { "metadata": { "colab_type": "text", "id": "rfM4YcPVPkNO" }, + "cell_type": "markdown", "source": [ - "## Restore the latest checkpoint" + "**Restore the latest checkpoint**" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "XhXsd0srPo8c" + "id": "XhXsd0srPo8c", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "# restoring the latest checkpoint in checkpoint_dir\n", "checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "P4M_vIbUi7c0" }, + "cell_type": "markdown", "source": [ - "## Display an image using the epoch number" + "## Generated images \n" ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, - "colab_type": "code", - "id": "WfO5wCdclHGL" + "colab_type": "text", + "id": "mLskt7EfXAjr" }, - "outputs": [], + "cell_type": "markdown", "source": [ + "\n", + "After training, it's time to generate some images! \n", + "The last step is to plot the generated images and voila!\n" + ] + }, + { + "metadata": { + "colab_type": "code", + "id": "WfO5wCdclHGL", + "colab": {} + }, + "cell_type": "code", + "source": [ + "# Display a single image using the epoch number\n", "def display_image(epoch_no):\n", " return PIL.Image.open('image_at_epoch_{:04d}.png'.format(epoch_no))" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "5x3q9_Oe5q0A" + "id": "5x3q9_Oe5q0A", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "display_image(EPOCHS)" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "NywiH3nL8guF" }, + "cell_type": "markdown", "source": [ - "## Generate a GIF of all the saved images."
- ] - }, - { "cell_type": "markdown", - "metadata": { - "colab_type": "text", - "id": "xmO0Dmu2WICn" - }, "source": [ - "\u003c!-- TODO(markdaoust): Remove the hack when Ipython version is updated --\u003e\n" + "**Generate a GIF of all the saved images**\n", + "\n", + "We will use imageio to create an animated gif using all the images saved during training." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "IGKQgENQ8lEI" + "id": "IGKQgENQ8lEI", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "with imageio.get_writer('dcgan.gif', mode='I') as writer:\n", " filenames = glob.glob('image*.png')\n", @@ -617,7 +784,7 @@ " last = -1\n", " for i,filename in enumerate(filenames):\n", " frame = 2*(i**0.5)\n", - " if round(frame) \u003e round(last):\n", + " if round(frame) > round(last):\n", " last = frame\n", " else:\n", " continue\n", @@ -628,67 +795,84 @@ " \n", "# this is a hack to display the gif inside the notebook\n", "os.system('cp dcgan.gif dcgan.gif.png')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "cGhC3-fMWSwl" + }, + "cell_type": "markdown", + "source": [ + "Display the animated gif with all the images generated during the training of GANs." + ] + }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "uV0yiKpzNP1b" + "id": "uV0yiKpzNP1b", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "display.Image(filename=\"dcgan.gif.png\")" - ] + ], + "execution_count": 0, + "outputs": [] }, { - "cell_type": "markdown", "metadata": { "colab_type": "text", "id": "6EEG-wePkmJQ" }, + "cell_type": "markdown", "source": [ - "To downlod the animation from Colab uncomment the code below:" + "**Download the animated gif**\n", + "\n", + "Uncomment the code below to download an animated gif from Colab." ] }, { - "cell_type": "code", - "execution_count": 0, "metadata": { - "colab": {}, "colab_type": "code", - "id": "4UJjSnIMOzOJ" + "id": "4UJjSnIMOzOJ", + "colab": {} }, - "outputs": [], + "cell_type": "code", "source": [ "#from google.colab import files\n", "#files.download('dcgan.gif')" + ], + "execution_count": 0, + "outputs": [] + }, + { + "metadata": { + "colab_type": "text", + "id": "k6qC-SbjK0yW" + }, + "cell_type": "markdown", + "source": [ + "## Learn more about GANs\n" + ] + }, + { + "metadata": { + "colab_type": "text", + "id": "xjjkT9KAK6H7" + }, + "cell_type": "markdown", + "source": [ + "We hope this tutorial was helpful! As a next step, you might like to experiment with a different dataset, for example the Large-scale CelebFaces Attributes (CelebA) dataset [available on Kaggle](https://www.kaggle.com/jessicali9530/celeba-dataset/home).\n", + "\n", + "To learn more about GANs:\n", + "\n", + "* Check out MIT's lecture (linked above), or [this](http://cs231n.stanford.edu/slides/2018/cs231n_2018_lecture12.pdf) lecture from Stanford's CS231n. 
\n", + "\n", + "* We also recommend the [CVPR 2018 Tutorial on GANs](https://sites.google.com/view/cvpr2018tutorialongans/), and the [NIPS 2016 Tutorial: Generative Adversarial Networks](https://arxiv.org/abs/1701.00160).\n" ] } - ], - "metadata": { - "accelerator": "GPU", - "colab": { - "collapsed_sections": [], - "name": "dcgan.ipynb", - "private_outputs": true, - "provenance": [ - { - "file_id": "1eb0NOTQapkYs3X0v-zL1x5_LFKgDISnp", - "timestamp": 1527173385672 - } - ], - "toc_visible": true, - "version": "0.3.2" - }, - "kernelspec": { - "display_name": "Python 3", - "language": "python", - "name": "python3" - } - }, - "nbformat": 4, - "nbformat_minor": 0 -} + ] +} \ No newline at end of file diff --git a/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png b/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png new file mode 100644 index 00000000000..b715bd83ef1 Binary files /dev/null and b/tensorflow/contrib/eager/python/examples/generative_examples/gans_diagram.png differ