Standardize some names from TFGAN -> TF-GAN.

A noop.

PiperOrigin-RevId: 227273663
This commit is contained in:
A. Unique TensorFlower 2019-12-30 02:54:20 -08:00 committed by TensorFlower Gardener
parent bbc8c15b85
commit c985bd0dce
6 changed files with 27 additions and 25 deletions

View File

@ -1,4 +1,4 @@
# Files for using TFGAN framework.
# Files for using TF-GAN framework.
load("//tensorflow:tensorflow.bzl", "py_test")
package(default_visibility = [

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A TFGAN-backed GAN Estimator."""
"""A TF-GAN-backed GAN Estimator."""
from __future__ import absolute_import
from __future__ import division
@ -56,10 +56,10 @@ _summary_type_map = {
class GANEstimator(estimator.Estimator):
"""An estimator for Generative Adversarial Networks (GANs).
This Estimator is backed by TFGAN. The network functions follow the TFGAN API
except for one exception: if either `generator_fn` or `discriminator_fn` have
an argument called `mode`, then the tf.Estimator mode is passed in for that
argument. This helps with operations like batch normalization, which have
This Estimator is backed by TF-GAN. The network functions follow the TF-GAN
API except for one exception: if either `generator_fn` or `discriminator_fn`
have an argument called `mode`, then the tf.Estimator mode is passed in for
that argument. This helps with operations like batch normalization, which have
different train and evaluation behavior.
Example:
@ -68,7 +68,7 @@ class GANEstimator(estimator.Estimator):
import tensorflow as tf
tfgan = tf.contrib.gan
# See TFGAN's `train.py` for a description of the generator and
# See TF-GAN's `train.py` for a description of the generator and
# discriminator API.
def generator_fn(generator_inputs):
...
@ -123,13 +123,13 @@ class GANEstimator(estimator.Estimator):
to continue training a previously saved model.
generator_fn: A python function that takes a Tensor, Tensor list, or
Tensor dictionary as inputs and returns the outputs of the GAN
generator. See `TFGAN` for more details and examples. Additionally, if
generator. See `TF-GAN` for more details and examples. Additionally, if
it has an argument called `mode`, the Estimator's `mode` will be passed
in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch
normalization.
discriminator_fn: A python function that takes the output of
`generator_fn` or real data in the GAN setup, and `generator_inputs`.
Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
Outputs a Tensor in the range [-inf, inf]. See `TF-GAN` for more details
and examples.
generator_loss_fn: The loss function on the generator. Takes a `GANModel`
tuple.

View File

@ -28,7 +28,7 @@ wasserstein_gradient_penalty
All losses must be able to accept 1D or 2D Tensors, so as to be compatible with
patchGAN style losses (https://arxiv.org/abs/1611.07004).
To make these losses usable in the TFGAN framework, please create a tuple
To make these losses usable in the TF-GAN framework, please create a tuple
version of the losses with `losses_utils.py`.
"""
@ -320,7 +320,7 @@ def wasserstein_gradient_penalty(
generated_data: Output of the generator.
generator_inputs: Exact argument to pass to the generator, which is used
as optional conditioning to the discriminator.
discriminator_fn: A discriminator function that conforms to TFGAN API.
discriminator_fn: A discriminator function that conforms to TF-GAN API.
discriminator_scope: If not `None`, reuse discriminators from this scope.
epsilon: A small positive number added for numerical stability when
computing the gradient norm.

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN utilities for loss functions that accept GANModel namedtuples.
"""TF-GAN utilities for loss functions that accept GANModel namedtuples.
The losses and penalties in this file all correspond to losses in
`losses_impl.py`. Losses in that file take individual arguments, whereas in this

View File

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Named tuples for TFGAN.
"""Named tuples for TF-GAN.
TFGAN training occurs in four steps, and each step communicates with the next
step via one of these named tuples. At each step, you can either use a TFGAN
TF-GAN training occurs in four steps, and each step communicates with the next
step via one of these named tuples. At each step, you can either use a TF-GAN
helper function in `train.py`, or you can manually construct a tuple.
"""

View File

@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TFGAN project provides a lightweight GAN training/testing framework.
"""The TF-GAN project provides a lightweight GAN training/testing framework.
This file contains the core helper functions to create and train a GAN model.
See the README or examples in `tensorflow_models` for details on how to use.
TFGAN training occurs in four steps:
TF-GAN training occurs in four steps:
1) Create a model
2) Add a loss
3) Create train ops
@ -645,9 +645,10 @@ def gan_loss(
type(model))
# Optionally create pooled model.
pooled_model = (
_tensor_pool_adjusted_model(model, tensor_pool_fn)
if tensor_pool_fn else model)
if tensor_pool_fn:
pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn)
else:
pooled_model = model
# Create standard losses.
gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
@ -665,10 +666,11 @@ def gan_loss(
if _use_aux_loss(mutual_information_penalty_weight):
gen_info_loss = tfgan_losses.mutual_information_penalty(
model, add_summaries=add_summaries)
dis_info_loss = (
gen_info_loss
if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty(
pooled_model, add_summaries=add_summaries))
if tensor_pool_fn is None:
dis_info_loss = gen_info_loss
else:
dis_info_loss = tfgan_losses.mutual_information_penalty(
pooled_model, add_summaries=add_summaries)
gen_loss += mutual_information_penalty_weight * gen_info_loss
dis_loss += mutual_information_penalty_weight * dis_info_loss
if _use_aux_loss(aux_cond_generator_weight):
@ -929,7 +931,7 @@ def gan_train_ops(
**kwargs):
"""Returns GAN train ops.
The highest-level call in TFGAN. It is composed of functions that can also
The highest-level call in TF-GAN. It is composed of functions that can also
be called, should a user require more control over some part of the GAN
training process.