diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index ae8320cfb27..e1d198fe7bc 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -1,4 +1,4 @@ -# Files for using TFGAN framework. +# Files for using TF-GAN framework. load("//tensorflow:tensorflow.bzl", "py_test") package(default_visibility = [ diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py index adb72228217..dd904611d1a 100644 --- a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""A TFGAN-backed GAN Estimator.""" +"""A TF-GAN-backed GAN Estimator.""" from __future__ import absolute_import from __future__ import division @@ -56,10 +56,10 @@ _summary_type_map = { class GANEstimator(estimator.Estimator): """An estimator for Generative Adversarial Networks (GANs). - This Estimator is backed by TFGAN. The network functions follow the TFGAN API - except for one exception: if either `generator_fn` or `discriminator_fn` have - an argument called `mode`, then the tf.Estimator mode is passed in for that - argument. This helps with operations like batch normalization, which have + This Estimator is backed by TF-GAN. The network functions follow the TF-GAN + API except for one exception: if either `generator_fn` or `discriminator_fn` + have an argument called `mode`, then the tf.Estimator mode is passed in for + that argument. This helps with operations like batch normalization, which have different train and evaluation behavior. Example: @@ -68,7 +68,7 @@ class GANEstimator(estimator.Estimator): import tensorflow as tf tfgan = tf.contrib.gan - # See TFGAN's `train.py` for a description of the generator and + # See TF-GAN's `train.py` for a description of the generator and # discriminator API. def generator_fn(generator_inputs): ... @@ -123,13 +123,13 @@ class GANEstimator(estimator.Estimator): to continue training a previously saved model. generator_fn: A python function that takes a Tensor, Tensor list, or Tensor dictionary as inputs and returns the outputs of the GAN - generator. See `TFGAN` for more details and examples. Additionally, if + generator. See `TF-GAN` for more details and examples. Additionally, if it has an argument called `mode`, the Estimator's `mode` will be passed in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch normalization. discriminator_fn: A python function that takes the output of `generator_fn` or real data in the GAN setup, and `generator_inputs`. - Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details + Outputs a Tensor in the range [-inf, inf]. See `TF-GAN` for more details and examples. generator_loss_fn: The loss function on the generator. Takes a `GANModel` tuple. diff --git a/tensorflow/contrib/gan/python/losses/python/losses_impl.py b/tensorflow/contrib/gan/python/losses/python/losses_impl.py index a0a86c6337e..930c91786f5 100644 --- a/tensorflow/contrib/gan/python/losses/python/losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/losses_impl.py @@ -28,7 +28,7 @@ wasserstein_gradient_penalty All losses must be able to accept 1D or 2D Tensors, so as to be compatible with patchGAN style losses (https://arxiv.org/abs/1611.07004). -To make these losses usable in the TFGAN framework, please create a tuple +To make these losses usable in the TF-GAN framework, please create a tuple version of the losses with `losses_utils.py`. """ @@ -320,7 +320,7 @@ def wasserstein_gradient_penalty( generated_data: Output of the generator. generator_inputs: Exact argument to pass to the generator, which is used as optional conditioning to the discriminator. - discriminator_fn: A discriminator function that conforms to TFGAN API. + discriminator_fn: A discriminator function that conforms to TF-GAN API. discriminator_scope: If not `None`, reuse discriminators from this scope. epsilon: A small positive number added for numerical stability when computing the gradient norm. diff --git a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py index 221c70c38bd..76e57df7f64 100644 --- a/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py +++ b/tensorflow/contrib/gan/python/losses/python/tuple_losses_impl.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""TFGAN utilities for loss functions that accept GANModel namedtuples. +"""TF-GAN utilities for loss functions that accept GANModel namedtuples. The losses and penalties in this file all correspond to losses in `losses_impl.py`. Losses in that file take individual arguments, whereas in this diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py index 969b68449d9..73dfee4fdee 100644 --- a/tensorflow/contrib/gan/python/namedtuples.py +++ b/tensorflow/contrib/gan/python/namedtuples.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Named tuples for TFGAN. +"""Named tuples for TF-GAN. -TFGAN training occurs in four steps, and each step communicates with the next -step via one of these named tuples. At each step, you can either use a TFGAN +TF-GAN training occurs in four steps, and each step communicates with the next +step via one of these named tuples. At each step, you can either use a TF-GAN helper function in `train.py`, or you can manually construct a tuple. """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py index 4c7bee41b33..f36a5d346e0 100644 --- a/tensorflow/contrib/gan/python/train.py +++ b/tensorflow/contrib/gan/python/train.py @@ -12,12 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""The TFGAN project provides a lightweight GAN training/testing framework. +"""The TF-GAN project provides a lightweight GAN training/testing framework. This file contains the core helper functions to create and train a GAN model. See the README or examples in `tensorflow_models` for details on how to use. -TFGAN training occurs in four steps: +TF-GAN training occurs in four steps: 1) Create a model 2) Add a loss 3) Create train ops @@ -645,9 +645,10 @@ def gan_loss( type(model)) # Optionally create pooled model. - pooled_model = ( - _tensor_pool_adjusted_model(model, tensor_pool_fn) - if tensor_pool_fn else model) + if tensor_pool_fn: + pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn) + else: + pooled_model = model # Create standard losses. gen_loss = generator_loss_fn(model, add_summaries=add_summaries) @@ -665,10 +666,11 @@ def gan_loss( if _use_aux_loss(mutual_information_penalty_weight): gen_info_loss = tfgan_losses.mutual_information_penalty( model, add_summaries=add_summaries) - dis_info_loss = ( - gen_info_loss - if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty( - pooled_model, add_summaries=add_summaries)) + if tensor_pool_fn is None: + dis_info_loss = gen_info_loss + else: + dis_info_loss = tfgan_losses.mutual_information_penalty( + pooled_model, add_summaries=add_summaries) gen_loss += mutual_information_penalty_weight * gen_info_loss dis_loss += mutual_information_penalty_weight * dis_info_loss if _use_aux_loss(aux_cond_generator_weight): @@ -929,7 +931,7 @@ def gan_train_ops( **kwargs): """Returns GAN train ops. - The highest-level call in TFGAN. It is composed of functions that can also + The highest-level call in TF-GAN. It is composed of functions that can also be called, should a user require more control over some part of the GAN training process.