Standardize some names from TFGAN
-> TF-GAN
.
A noop. PiperOrigin-RevId: 227273663
This commit is contained in:
parent
bbc8c15b85
commit
c985bd0dce
@ -1,4 +1,4 @@
|
||||
# Files for using TFGAN framework.
|
||||
# Files for using TF-GAN framework.
|
||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||
|
||||
package(default_visibility = [
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A TFGAN-backed GAN Estimator."""
|
||||
"""A TF-GAN-backed GAN Estimator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
@ -56,10 +56,10 @@ _summary_type_map = {
|
||||
class GANEstimator(estimator.Estimator):
|
||||
"""An estimator for Generative Adversarial Networks (GANs).
|
||||
|
||||
This Estimator is backed by TFGAN. The network functions follow the TFGAN API
|
||||
except for one exception: if either `generator_fn` or `discriminator_fn` have
|
||||
an argument called `mode`, then the tf.Estimator mode is passed in for that
|
||||
argument. This helps with operations like batch normalization, which have
|
||||
This Estimator is backed by TF-GAN. The network functions follow the TF-GAN
|
||||
API except for one exception: if either `generator_fn` or `discriminator_fn`
|
||||
have an argument called `mode`, then the tf.Estimator mode is passed in for
|
||||
that argument. This helps with operations like batch normalization, which have
|
||||
different train and evaluation behavior.
|
||||
|
||||
Example:
|
||||
@ -68,7 +68,7 @@ class GANEstimator(estimator.Estimator):
|
||||
import tensorflow as tf
|
||||
tfgan = tf.contrib.gan
|
||||
|
||||
# See TFGAN's `train.py` for a description of the generator and
|
||||
# See TF-GAN's `train.py` for a description of the generator and
|
||||
# discriminator API.
|
||||
def generator_fn(generator_inputs):
|
||||
...
|
||||
@ -123,13 +123,13 @@ class GANEstimator(estimator.Estimator):
|
||||
to continue training a previously saved model.
|
||||
generator_fn: A python function that takes a Tensor, Tensor list, or
|
||||
Tensor dictionary as inputs and returns the outputs of the GAN
|
||||
generator. See `TFGAN` for more details and examples. Additionally, if
|
||||
generator. See `TF-GAN` for more details and examples. Additionally, if
|
||||
it has an argument called `mode`, the Estimator's `mode` will be passed
|
||||
in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch
|
||||
normalization.
|
||||
discriminator_fn: A python function that takes the output of
|
||||
`generator_fn` or real data in the GAN setup, and `generator_inputs`.
|
||||
Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
|
||||
Outputs a Tensor in the range [-inf, inf]. See `TF-GAN` for more details
|
||||
and examples.
|
||||
generator_loss_fn: The loss function on the generator. Takes a `GANModel`
|
||||
tuple.
|
||||
|
@ -28,7 +28,7 @@ wasserstein_gradient_penalty
|
||||
All losses must be able to accept 1D or 2D Tensors, so as to be compatible with
|
||||
patchGAN style losses (https://arxiv.org/abs/1611.07004).
|
||||
|
||||
To make these losses usable in the TFGAN framework, please create a tuple
|
||||
To make these losses usable in the TF-GAN framework, please create a tuple
|
||||
version of the losses with `losses_utils.py`.
|
||||
"""
|
||||
|
||||
@ -320,7 +320,7 @@ def wasserstein_gradient_penalty(
|
||||
generated_data: Output of the generator.
|
||||
generator_inputs: Exact argument to pass to the generator, which is used
|
||||
as optional conditioning to the discriminator.
|
||||
discriminator_fn: A discriminator function that conforms to TFGAN API.
|
||||
discriminator_fn: A discriminator function that conforms to TF-GAN API.
|
||||
discriminator_scope: If not `None`, reuse discriminators from this scope.
|
||||
epsilon: A small positive number added for numerical stability when
|
||||
computing the gradient norm.
|
||||
|
@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TFGAN utilities for loss functions that accept GANModel namedtuples.
|
||||
"""TF-GAN utilities for loss functions that accept GANModel namedtuples.
|
||||
|
||||
The losses and penalties in this file all correspond to losses in
|
||||
`losses_impl.py`. Losses in that file take individual arguments, whereas in this
|
||||
|
@ -12,10 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Named tuples for TFGAN.
|
||||
"""Named tuples for TF-GAN.
|
||||
|
||||
TFGAN training occurs in four steps, and each step communicates with the next
|
||||
step via one of these named tuples. At each step, you can either use a TFGAN
|
||||
TF-GAN training occurs in four steps, and each step communicates with the next
|
||||
step via one of these named tuples. At each step, you can either use a TF-GAN
|
||||
helper function in `train.py`, or you can manually construct a tuple.
|
||||
"""
|
||||
|
||||
|
@ -12,12 +12,12 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""The TFGAN project provides a lightweight GAN training/testing framework.
|
||||
"""The TF-GAN project provides a lightweight GAN training/testing framework.
|
||||
|
||||
This file contains the core helper functions to create and train a GAN model.
|
||||
See the README or examples in `tensorflow_models` for details on how to use.
|
||||
|
||||
TFGAN training occurs in four steps:
|
||||
TF-GAN training occurs in four steps:
|
||||
1) Create a model
|
||||
2) Add a loss
|
||||
3) Create train ops
|
||||
@ -645,9 +645,10 @@ def gan_loss(
|
||||
type(model))
|
||||
|
||||
# Optionally create pooled model.
|
||||
pooled_model = (
|
||||
_tensor_pool_adjusted_model(model, tensor_pool_fn)
|
||||
if tensor_pool_fn else model)
|
||||
if tensor_pool_fn:
|
||||
pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn)
|
||||
else:
|
||||
pooled_model = model
|
||||
|
||||
# Create standard losses.
|
||||
gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
|
||||
@ -665,10 +666,11 @@ def gan_loss(
|
||||
if _use_aux_loss(mutual_information_penalty_weight):
|
||||
gen_info_loss = tfgan_losses.mutual_information_penalty(
|
||||
model, add_summaries=add_summaries)
|
||||
dis_info_loss = (
|
||||
gen_info_loss
|
||||
if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty(
|
||||
pooled_model, add_summaries=add_summaries))
|
||||
if tensor_pool_fn is None:
|
||||
dis_info_loss = gen_info_loss
|
||||
else:
|
||||
dis_info_loss = tfgan_losses.mutual_information_penalty(
|
||||
pooled_model, add_summaries=add_summaries)
|
||||
gen_loss += mutual_information_penalty_weight * gen_info_loss
|
||||
dis_loss += mutual_information_penalty_weight * dis_info_loss
|
||||
if _use_aux_loss(aux_cond_generator_weight):
|
||||
@ -929,7 +931,7 @@ def gan_train_ops(
|
||||
**kwargs):
|
||||
"""Returns GAN train ops.
|
||||
|
||||
The highest-level call in TFGAN. It is composed of functions that can also
|
||||
The highest-level call in TF-GAN. It is composed of functions that can also
|
||||
be called, should a user require more control over some part of the GAN
|
||||
training process.
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user