Standardize some names from TFGAN
-> TF-GAN
.
A noop. PiperOrigin-RevId: 227273663
This commit is contained in:
parent
bbc8c15b85
commit
c985bd0dce
@ -1,4 +1,4 @@
|
|||||||
# Files for using TFGAN framework.
|
# Files for using TF-GAN framework.
|
||||||
load("//tensorflow:tensorflow.bzl", "py_test")
|
load("//tensorflow:tensorflow.bzl", "py_test")
|
||||||
|
|
||||||
package(default_visibility = [
|
package(default_visibility = [
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""A TFGAN-backed GAN Estimator."""
|
"""A TF-GAN-backed GAN Estimator."""
|
||||||
|
|
||||||
from __future__ import absolute_import
|
from __future__ import absolute_import
|
||||||
from __future__ import division
|
from __future__ import division
|
||||||
@ -56,10 +56,10 @@ _summary_type_map = {
|
|||||||
class GANEstimator(estimator.Estimator):
|
class GANEstimator(estimator.Estimator):
|
||||||
"""An estimator for Generative Adversarial Networks (GANs).
|
"""An estimator for Generative Adversarial Networks (GANs).
|
||||||
|
|
||||||
This Estimator is backed by TFGAN. The network functions follow the TFGAN API
|
This Estimator is backed by TF-GAN. The network functions follow the TF-GAN
|
||||||
except for one exception: if either `generator_fn` or `discriminator_fn` have
|
API except for one exception: if either `generator_fn` or `discriminator_fn`
|
||||||
an argument called `mode`, then the tf.Estimator mode is passed in for that
|
have an argument called `mode`, then the tf.Estimator mode is passed in for
|
||||||
argument. This helps with operations like batch normalization, which have
|
that argument. This helps with operations like batch normalization, which have
|
||||||
different train and evaluation behavior.
|
different train and evaluation behavior.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
@ -68,7 +68,7 @@ class GANEstimator(estimator.Estimator):
|
|||||||
import tensorflow as tf
|
import tensorflow as tf
|
||||||
tfgan = tf.contrib.gan
|
tfgan = tf.contrib.gan
|
||||||
|
|
||||||
# See TFGAN's `train.py` for a description of the generator and
|
# See TF-GAN's `train.py` for a description of the generator and
|
||||||
# discriminator API.
|
# discriminator API.
|
||||||
def generator_fn(generator_inputs):
|
def generator_fn(generator_inputs):
|
||||||
...
|
...
|
||||||
@ -123,13 +123,13 @@ class GANEstimator(estimator.Estimator):
|
|||||||
to continue training a previously saved model.
|
to continue training a previously saved model.
|
||||||
generator_fn: A python function that takes a Tensor, Tensor list, or
|
generator_fn: A python function that takes a Tensor, Tensor list, or
|
||||||
Tensor dictionary as inputs and returns the outputs of the GAN
|
Tensor dictionary as inputs and returns the outputs of the GAN
|
||||||
generator. See `TFGAN` for more details and examples. Additionally, if
|
generator. See `TF-GAN` for more details and examples. Additionally, if
|
||||||
it has an argument called `mode`, the Estimator's `mode` will be passed
|
it has an argument called `mode`, the Estimator's `mode` will be passed
|
||||||
in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch
|
in (ex TRAIN, EVAL, PREDICT). This is useful for things like batch
|
||||||
normalization.
|
normalization.
|
||||||
discriminator_fn: A python function that takes the output of
|
discriminator_fn: A python function that takes the output of
|
||||||
`generator_fn` or real data in the GAN setup, and `generator_inputs`.
|
`generator_fn` or real data in the GAN setup, and `generator_inputs`.
|
||||||
Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
|
Outputs a Tensor in the range [-inf, inf]. See `TF-GAN` for more details
|
||||||
and examples.
|
and examples.
|
||||||
generator_loss_fn: The loss function on the generator. Takes a `GANModel`
|
generator_loss_fn: The loss function on the generator. Takes a `GANModel`
|
||||||
tuple.
|
tuple.
|
||||||
|
@ -28,7 +28,7 @@ wasserstein_gradient_penalty
|
|||||||
All losses must be able to accept 1D or 2D Tensors, so as to be compatible with
|
All losses must be able to accept 1D or 2D Tensors, so as to be compatible with
|
||||||
patchGAN style losses (https://arxiv.org/abs/1611.07004).
|
patchGAN style losses (https://arxiv.org/abs/1611.07004).
|
||||||
|
|
||||||
To make these losses usable in the TFGAN framework, please create a tuple
|
To make these losses usable in the TF-GAN framework, please create a tuple
|
||||||
version of the losses with `losses_utils.py`.
|
version of the losses with `losses_utils.py`.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@ -320,7 +320,7 @@ def wasserstein_gradient_penalty(
|
|||||||
generated_data: Output of the generator.
|
generated_data: Output of the generator.
|
||||||
generator_inputs: Exact argument to pass to the generator, which is used
|
generator_inputs: Exact argument to pass to the generator, which is used
|
||||||
as optional conditioning to the discriminator.
|
as optional conditioning to the discriminator.
|
||||||
discriminator_fn: A discriminator function that conforms to TFGAN API.
|
discriminator_fn: A discriminator function that conforms to TF-GAN API.
|
||||||
discriminator_scope: If not `None`, reuse discriminators from this scope.
|
discriminator_scope: If not `None`, reuse discriminators from this scope.
|
||||||
epsilon: A small positive number added for numerical stability when
|
epsilon: A small positive number added for numerical stability when
|
||||||
computing the gradient norm.
|
computing the gradient norm.
|
||||||
|
@ -12,7 +12,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""TFGAN utilities for loss functions that accept GANModel namedtuples.
|
"""TF-GAN utilities for loss functions that accept GANModel namedtuples.
|
||||||
|
|
||||||
The losses and penalties in this file all correspond to losses in
|
The losses and penalties in this file all correspond to losses in
|
||||||
`losses_impl.py`. Losses in that file take individual arguments, whereas in this
|
`losses_impl.py`. Losses in that file take individual arguments, whereas in this
|
||||||
|
@ -12,10 +12,10 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""Named tuples for TFGAN.
|
"""Named tuples for TF-GAN.
|
||||||
|
|
||||||
TFGAN training occurs in four steps, and each step communicates with the next
|
TF-GAN training occurs in four steps, and each step communicates with the next
|
||||||
step via one of these named tuples. At each step, you can either use a TFGAN
|
step via one of these named tuples. At each step, you can either use a TF-GAN
|
||||||
helper function in `train.py`, or you can manually construct a tuple.
|
helper function in `train.py`, or you can manually construct a tuple.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
@ -12,12 +12,12 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
"""The TFGAN project provides a lightweight GAN training/testing framework.
|
"""The TF-GAN project provides a lightweight GAN training/testing framework.
|
||||||
|
|
||||||
This file contains the core helper functions to create and train a GAN model.
|
This file contains the core helper functions to create and train a GAN model.
|
||||||
See the README or examples in `tensorflow_models` for details on how to use.
|
See the README or examples in `tensorflow_models` for details on how to use.
|
||||||
|
|
||||||
TFGAN training occurs in four steps:
|
TF-GAN training occurs in four steps:
|
||||||
1) Create a model
|
1) Create a model
|
||||||
2) Add a loss
|
2) Add a loss
|
||||||
3) Create train ops
|
3) Create train ops
|
||||||
@ -645,9 +645,10 @@ def gan_loss(
|
|||||||
type(model))
|
type(model))
|
||||||
|
|
||||||
# Optionally create pooled model.
|
# Optionally create pooled model.
|
||||||
pooled_model = (
|
if tensor_pool_fn:
|
||||||
_tensor_pool_adjusted_model(model, tensor_pool_fn)
|
pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn)
|
||||||
if tensor_pool_fn else model)
|
else:
|
||||||
|
pooled_model = model
|
||||||
|
|
||||||
# Create standard losses.
|
# Create standard losses.
|
||||||
gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
|
gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
|
||||||
@ -665,10 +666,11 @@ def gan_loss(
|
|||||||
if _use_aux_loss(mutual_information_penalty_weight):
|
if _use_aux_loss(mutual_information_penalty_weight):
|
||||||
gen_info_loss = tfgan_losses.mutual_information_penalty(
|
gen_info_loss = tfgan_losses.mutual_information_penalty(
|
||||||
model, add_summaries=add_summaries)
|
model, add_summaries=add_summaries)
|
||||||
dis_info_loss = (
|
if tensor_pool_fn is None:
|
||||||
gen_info_loss
|
dis_info_loss = gen_info_loss
|
||||||
if tensor_pool_fn is None else tfgan_losses.mutual_information_penalty(
|
else:
|
||||||
pooled_model, add_summaries=add_summaries))
|
dis_info_loss = tfgan_losses.mutual_information_penalty(
|
||||||
|
pooled_model, add_summaries=add_summaries)
|
||||||
gen_loss += mutual_information_penalty_weight * gen_info_loss
|
gen_loss += mutual_information_penalty_weight * gen_info_loss
|
||||||
dis_loss += mutual_information_penalty_weight * dis_info_loss
|
dis_loss += mutual_information_penalty_weight * dis_info_loss
|
||||||
if _use_aux_loss(aux_cond_generator_weight):
|
if _use_aux_loss(aux_cond_generator_weight):
|
||||||
@ -929,7 +931,7 @@ def gan_train_ops(
|
|||||||
**kwargs):
|
**kwargs):
|
||||||
"""Returns GAN train ops.
|
"""Returns GAN train ops.
|
||||||
|
|
||||||
The highest-level call in TFGAN. It is composed of functions that can also
|
The highest-level call in TF-GAN. It is composed of functions that can also
|
||||||
be called, should a user require more control over some part of the GAN
|
be called, should a user require more control over some part of the GAN
|
||||||
training process.
|
training process.
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user