Standardize some names from TFGAN -> TF-GAN.

A noop.

PiperOrigin-RevId: 227273663
This commit is contained in:
A. Unique TensorFlower 2019-12-30 02:54:20 -08:00 committed by TensorFlower Gardener
parent bbc8c15b85
commit c985bd0dce
6 changed files with 27 additions and 25 deletions

View File

@ -1,4 +1,4 @@
# Files for using TFGAN framework. # Files for using TF-GAN framework.
load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow:tensorflow.bzl", "py_test")
package(default_visibility = [ package(default_visibility = [

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""A TFGAN-backed GAN Estimator.""" """A TF-GAN-backed GAN Estimator."""
from __future__ import absolute_import from __future__ import absolute_import
from __future__ import division from __future__ import division
@ -56,10 +56,10 @@ _summary_type_map = {
class GANEstimator(estimator.Estimator): class GANEstimator(estimator.Estimator):
"""An estimator for Generative Adversarial Networks (GANs). """An estimator for Generative Adversarial Networks (GANs).
This Estimator is backed by TFGAN. The network functions follow the TFGAN API This Estimator is backed by TF-GAN. The network functions follow the TF-GAN
except for one exception: if either `generator_fn` or `discriminator_fn` have API except for one exception: if either `generator_fn` or `discriminator_fn`
an argument called `mode`, then the tf.Estimator mode is passed in for that have an argument called `mode`, then the tf.Estimator mode is passed in for
argument. This helps with operations like batch normalization, which have that argument. This helps with operations like batch normalization, which have
different train and evaluation behavior. different train and evaluation behavior.
Example: Example:
@ -68,7 +68,7 @@ class GANEstimator(estimator.Estimator):
import tensorflow as tf import tensorflow as tf
tfgan = tf.contrib.gan tfgan = tf.contrib.gan
# See TFGAN's `train.py` for a description of the generator and # See TF-GAN's `train.py` for a description of the generator and
# discriminator API. # discriminator API.
def generator_fn(generator_inputs): def generator_fn(generator_inputs):
... ...
@ -123,13 +123,13 @@ class GANEstimator(estimator.Estimator):
to continue training a previously saved model. to continue training a previously saved model.
generator_fn: A python function that takes a Tensor, Tensor list, or generator_fn: A python function that takes a Tensor, Tensor list, or
Tensor dictionary as inputs and returns the outputs of the GAN Tensor dictionary as inputs and returns the outputs of the GAN
generator. See `TFGAN` for more details and examples. Additionally, if generator. See `TF-GAN` for more details and examples. Additionally, if
it has an argument called `mode`, the Estimator's `mode` will be passed it has an argument called `mode`, the Estimator's `mode` will be passed
in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch
normalization. normalization.
discriminator_fn: A python function that takes the output of discriminator_fn: A python function that takes the output of
`generator_fn` or real data in the GAN setup, and `generator_inputs`. `generator_fn` or real data in the GAN setup, and `generator_inputs`.
Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details Outputs a Tensor in the range [-inf, inf]. See `TF-GAN` for more details
and examples. and examples.
generator_loss_fn: The loss function on the generator. Takes a `GANModel` generator_loss_fn: The loss function on the generator. Takes a `GANModel`
tuple. tuple.

View File

@ -28,7 +28,7 @@ wasserstein_gradient_penalty
All losses must be able to accept 1D or 2D Tensors, so as to be compatible with All losses must be able to accept 1D or 2D Tensors, so as to be compatible with
patchGAN style losses (https://arxiv.org/abs/1611.07004). patchGAN style losses (https://arxiv.org/abs/1611.07004).
To make these losses usable in the TFGAN framework, please create a tuple To make these losses usable in the TF-GAN framework, please create a tuple
version of the losses with `losses_utils.py`. version of the losses with `losses_utils.py`.
""" """
@ -320,7 +320,7 @@ def wasserstein_gradient_penalty(
generated_data: Output of the generator. generated_data: Output of the generator.
generator_inputs: Exact argument to pass to the generator, which is used generator_inputs: Exact argument to pass to the generator, which is used
as optional conditioning to the discriminator. as optional conditioning to the discriminator.
discriminator_fn: A discriminator function that conforms to TFGAN API. discriminator_fn: A discriminator function that conforms to TF-GAN API.
discriminator_scope: If not `None`, reuse discriminators from this scope. discriminator_scope: If not `None`, reuse discriminators from this scope.
epsilon: A small positive number added for numerical stability when epsilon: A small positive number added for numerical stability when
computing the gradient norm. computing the gradient norm.

View File

@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""TFGAN utilities for loss functions that accept GANModel namedtuples. """TF-GAN utilities for loss functions that accept GANModel namedtuples.
The losses and penalties in this file all correspond to losses in The losses and penalties in this file all correspond to losses in
`losses_impl.py`. Losses in that file take individual arguments, whereas in this `losses_impl.py`. Losses in that file take individual arguments, whereas in this

View File

@ -12,10 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""Named tuples for TFGAN. """Named tuples for TF-GAN.
TFGAN training occurs in four steps, and each step communicates with the next TF-GAN training occurs in four steps, and each step communicates with the next
step via one of these named tuples. At each step, you can either use a TFGAN step via one of these named tuples. At each step, you can either use a TF-GAN
helper function in `train.py`, or you can manually construct a tuple. helper function in `train.py`, or you can manually construct a tuple.
""" """

View File

@ -12,12 +12,12 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
# ============================================================================== # ==============================================================================
"""The TFGAN project provides a lightweight GAN training/testing framework. """The TF-GAN project provides a lightweight GAN training/testing framework.
This file contains the core helper functions to create and train a GAN model. This file contains the core helper functions to create and train a GAN model.
See the README or examples in `tensorflow_models` for details on how to use. See the README or examples in `tensorflow_models` for details on how to use.
TFGAN training occurs in four steps: TF-GAN training occurs in four steps:
1) Create a model 1) Create a model
2) Add a loss 2) Add a loss
3) Create train ops 3) Create train ops
@ -645,9 +645,10 @@ def gan_loss(
type(model)) type(model))
# Optionally create pooled model. # Optionally create pooled model.
pooled_model = ( if tensor_pool_fn:
_tensor_pool_adjusted_model(model, tensor_pool_fn) pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn)
if tensor_pool_fn else model) else:
pooled_model = model
# Create standard losses. # Create standard losses.
gen_loss = generator_loss_fn(model, add_summaries=add_summaries) gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
@ -665,10 +666,11 @@ def gan_loss(
if _use_aux_loss(mutual_information_penalty_weight): if _use_aux_loss(mutual_information_penalty_weight):
gen_info_loss = tfgan_losses.mutual_information_penalty( gen_info_loss = tfgan_losses.mutual_information_penalty(
model, add_summaries=add_summaries) model, add_summaries=add_summaries)
dis_info_loss = ( if tensor_pool_fn is None:
gen_info_loss dis_info_loss = gen_info_loss
if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty( else:
pooled_model, add_summaries=add_summaries)) dis_info_loss = tfgan_losses.mutual_information_penalty(
pooled_model, add_summaries=add_summaries)
gen_loss += mutual_information_penalty_weight * gen_info_loss gen_loss += mutual_information_penalty_weight * gen_info_loss
dis_loss += mutual_information_penalty_weight * dis_info_loss dis_loss += mutual_information_penalty_weight * dis_info_loss
if _use_aux_loss(aux_cond_generator_weight): if _use_aux_loss(aux_cond_generator_weight):
@ -929,7 +931,7 @@ def gan_train_ops(
**kwargs): **kwargs):
"""Returns GAN train ops. """Returns GAN train ops.
The highest-level call in TFGAN. It is composed of functions that can also The highest-level call in TF-GAN. It is composed of functions that can also
be called, should a user require more control over some part of the GAN be called, should a user require more control over some part of the GAN
training process. training process.