Migrate core TFGAN functions to opensource.

PiperOrigin-RevId: 168391923
This commit is contained in:
A. Unique TensorFlower 2017-09-12 10:09:54 -07:00 committed by TensorFlower Gardener
parent bc6b60f1bc
commit f63aa7f49f
5 changed files with 1769 additions and 4 deletions

View File

@ -16,6 +16,60 @@ py_library(
deps = [
":features",
":losses",
":namedtuples",
":train",
],
)
py_library(
name = "namedtuples",
srcs = ["python/namedtuples.py"],
srcs_version = "PY2AND3",
)
py_library(
name = "train",
srcs = ["python/train.py"],
srcs_version = "PY2AND3",
deps = [
":losses",
":namedtuples",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/slim:learning",
"//tensorflow/contrib/training:training_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:check_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:init_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/ops/distributions",
"//tensorflow/python/ops/losses",
],
)
py_test(
name = "train_test",
srcs = ["python/train_test.py"],
srcs_version = "PY2AND3",
deps = [
":namedtuples",
":train",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/contrib/slim:learning",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:constant_op",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:random_ops",
"//tensorflow/python:random_seed",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python:variables",
"//tensorflow/python/ops/distributions",
"//third_party/py/numpy",
],
)

View File

@ -1,4 +1,4 @@
# Copyright 2016 Google Inc. All Rights Reserved.
# Copyright 2017 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@ -21,7 +21,20 @@ from __future__ import print_function
# Collapse TFGAN into a tiered namespace.
from tensorflow.contrib.gan.python import features
from tensorflow.contrib.gan.python import losses
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python import train
del absolute_import
del division
del print_function
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.gan.python.namedtuples import *
from tensorflow.contrib.gan.python.train import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'features',
'losses',
]
_allowed_symbols += train.__all__
_allowed_symbols += namedtuples.__all__
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,149 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Named tuples for TFGAN."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
__all__ = [
'GANModel',
'InfoGANModel',
'ACGANModel',
'GANLoss',
'GANTrainOps',
'GANTrainSteps',
]
class GANModel(
collections.namedtuple('GANModel', (
'generator_inputs',
'generated_data',
'generator_variables',
'generator_scope',
'generator_fn',
'real_data',
'discriminator_real_outputs',
'discriminator_gen_outputs',
'discriminator_variables',
'discriminator_scope',
'discriminator_fn',
))):
"""A GANModel contains all the pieces needed for GAN training.
Generative Adversarial Networks (https://arxiv.org/abs/1406.2661) attempt
to create an implicit generative model of data by solving a two agent game.
The generator generates candidate examples that are supposed to match the
data distribution, and the discriminator aims to tell the real examples
apart from the generated samples.
Args:
generator_inputs: The random noise source that acts as input to the
generator.
generated_data: The generated output data of the GAN.
generator_variables: A list of all generator variables.
generator_scope: Variable scope all generator variables live in.
generator_fn: The generator function.
real_data: A tensor or real data.
discriminator_real_outputs: The discriminator's output on real data.
discriminator_gen_outputs: The discriminator's output on generated data.
discriminator_variables: A list of all discriminator variables.
discriminator_scope: Variable scope all discriminator variables live in.
discriminator_fn: The discriminator function.
"""
# TODO(joelshor): Have this class inherit from `GANModel`.
class InfoGANModel(
collections.namedtuple('InfoGANModel', GANModel._fields + (
'structured_generator_inputs',
'predicted_distributions',
))):
"""An InfoGANModel contains all the pieces needed for InfoGAN training.
See https://arxiv.org/abs/1606.03657 for more details.
Args:
structured_generator_inputs: A list of Tensors representing the random noise
that must have high mutual information with the generator output. List
length should match `predicted_distributions`.
predicted_distributions: A list of tf.Distributions. Predicted by the
recognizer, and used to evaluate the likelihood of the structured noise.
List length should match `structured_generator_inputs`.
"""
class ACGANModel(
collections.namedtuple('ACGANModel', GANModel._fields +
('one_hot_labels',
'discriminator_real_classification_logits',
'discriminator_gen_classification_logits',))):
"""An ACGANModel contains all the pieces needed for ACGAN training.
See https://arxiv.org/abs/1610.09585 for more details.
Args:
one_hot_labels: A Tensor holding one-hot-labels for the batch.
discriminator_real_classification_logits: Classification logits for real
data.
discriminator_gen_classification_logits: Classification logits for generated
data.
"""
class GANLoss(
collections.namedtuple('GANLoss', (
'generator_loss',
'discriminator_loss'
))):
"""GANLoss contains the generator and discriminator losses.
Args:
generator_loss: A tensor for the generator loss..
discriminator_loss: A tensor for the discriminator loss.
"""
class GANTrainOps(
collections.namedtuple('GANTrainOps', (
'generator_train_op',
'discriminator_train_op',
'global_step_inc_op'
))):
"""GANTrainOps contains the training ops.
Args:
generator_train_op: Op that performs a generator update step.
discriminator_train_op: Op that performs a discriminator update step.
global_step_inc_op: Op that increments the shared global step.
"""
class GANTrainSteps(
collections.namedtuple('GANTrainSteps', (
'generator_train_steps',
'discriminator_train_steps'
))):
"""Contains configuration for the GAN Training.
Args:
generator_train_steps: Number of generator steps to take in each GAN step.
discriminator_train_steps: Number of discriminator steps to take in each GAN
step.
"""

View File

@ -0,0 +1,804 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TFGAN project provides a lightweight GAN training/testing framework.
See examples in `tensorflow_models` for details on how to use.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import losses as tfgan_losses
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.contrib.training.python.training import training
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.distributions import distribution as ds
from tensorflow.python.ops.losses import losses
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
__all__ = [
'gan_model',
'infogan_model',
'acgan_model',
'gan_loss',
'gan_train_ops',
'gan_train',
'get_sequential_train_hooks',
'get_joint_train_hooks',
'get_sequential_train_steps',
]
def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
"""Convert input, list of inputs, or dictionary of inputs to Tensors."""
if isinstance(tensor_or_l_or_d, (list, tuple)):
return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
elif isinstance(tensor_or_l_or_d, dict):
return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
else:
return ops.convert_to_tensor(tensor_or_l_or_d)
def gan_model(
# Lambdas defining models.
generator_fn,
discriminator_fn,
# Real data and conditioning.
real_data,
generator_inputs,
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator',
# Options.
check_shapes=True):
"""Returns GAN model outputs and variables.
Args:
generator_fn: A python lambda that takes `generator_inputs` as inputs and
returns the outputs of the GAN generator.
discriminator_fn: A python lambda that takes `real_data`/`generated data`
and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
real_data: A Tensor representing the real data.
generator_inputs: A Tensor or list of Tensors to the generator. In the
vanilla GAN case, this might be a single noise Tensor. In the conditional
GAN case, this might be the generator's conditioning.
generator_scope: Optional generator variable scope. Useful if you want to
reuse a subgraph that has already been created.
discriminator_scope: Optional discriminator variable scope. Useful if you
want to reuse a subgraph that has already been created.
check_shapes: If `True`, check that generator produces Tensors that are the
same shape as real data. Otherwise, skip this check.
Returns:
A GANModel namedtuple.
Raises:
ValueError: If the generator outputs a Tensor that isn't the same shape as
`real_data`.
"""
# Create models
with variable_scope.variable_scope(generator_scope) as gen_scope:
generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as dis_scope:
discriminator_gen_outputs = discriminator_fn(generated_data,
generator_inputs)
with variable_scope.variable_scope(dis_scope, reuse=True):
real_data = ops.convert_to_tensor(real_data)
discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)
if check_shapes:
if not generated_data.shape.is_compatible_with(real_data.shape):
raise ValueError(
'Generator output shape (%s) must be the same shape as real data '
'(%s).' % (generated_data.shape, real_data.shape))
# Get model-specific variables.
generator_variables = variables_lib.get_trainable_variables(gen_scope)
discriminator_variables = variables_lib.get_trainable_variables(dis_scope)
return namedtuples.GANModel(
generator_inputs,
generated_data,
generator_variables,
gen_scope,
generator_fn,
real_data,
discriminator_real_outputs,
discriminator_gen_outputs,
discriminator_variables,
dis_scope,
discriminator_fn)
def _validate_distributions(distributions_l, noise_l):
if not isinstance(distributions_l, (tuple, list)):
raise ValueError('`predicted_distributions` must be a list. Instead, found '
'%s.' % type(distributions_l))
for dist in distributions_l:
if not isinstance(dist, ds.Distribution):
raise ValueError('Every element in `predicted_distributions` must be a '
'`tf.Distribution`. Instead, found %s.' % type(dist))
if len(distributions_l) != len(noise_l):
raise ValueError('Length of `predicted_distributions` %i must be the same '
'as the length of structured noise %i.' %
(len(distributions_l), len(noise_l)))
def infogan_model(
# Lambdas defining models.
generator_fn,
discriminator_fn,
# Real data and conditioning.
real_data,
unstructured_generator_inputs,
structured_generator_inputs,
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator'):
"""Returns an InfoGAN model outputs and variables.
See https://arxiv.org/abs/1606.03657 for more details.
Args:
generator_fn: A python lambda that takes a list of Tensors as inputs and
returns the outputs of the GAN generator.
discriminator_fn: A python lambda that takes `real_data`/`generated data`
and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list).
`logits` are in the range [-inf, inf], and `distribution_list` is a list
of Tensorflow distributions representing the predicted noise distribution
of the ith structure noise.
real_data: A Tensor representing the real data.
unstructured_generator_inputs: A list of Tensors to the generator.
These tensors represent the unstructured noise or conditioning.
structured_generator_inputs: A list of Tensors to the generator.
These tensors must have high mutual information with the recognizer.
generator_scope: Optional generator variable scope. Useful if you want to
reuse a subgraph that has already been created.
discriminator_scope: Optional discriminator variable scope. Useful if you
want to reuse a subgraph that has already been created.
Returns:
An InfoGANModel namedtuple.
Raises:
ValueError: If the generator outputs a Tensor that isn't the same shape as
`real_data`.
ValueError: If the discriminator output is malformed.
"""
# Create models
with variable_scope.variable_scope(generator_scope) as gen_scope:
unstructured_generator_inputs = _convert_tensor_or_l_or_d(
unstructured_generator_inputs)
structured_generator_inputs = _convert_tensor_or_l_or_d(
structured_generator_inputs)
generator_inputs = (
unstructured_generator_inputs + structured_generator_inputs)
generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as disc_scope:
dis_gen_outputs, predicted_distributions = discriminator_fn(
generated_data, generator_inputs)
_validate_distributions(predicted_distributions, structured_generator_inputs)
with variable_scope.variable_scope(disc_scope, reuse=True):
real_data = ops.convert_to_tensor(real_data)
dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)
if not generated_data.get_shape().is_compatible_with(real_data.get_shape()):
raise ValueError(
'Generator output shape (%s) must be the same shape as real data '
'(%s).' % (generated_data.get_shape(), real_data.get_shape()))
# Get model-specific variables.
generator_variables = variables_lib.get_trainable_variables(gen_scope)
discriminator_variables = variables_lib.get_trainable_variables(
disc_scope)
return namedtuples.InfoGANModel(
generator_inputs,
generated_data,
generator_variables,
gen_scope,
generator_fn,
real_data,
dis_real_outputs,
dis_gen_outputs,
discriminator_variables,
disc_scope,
lambda x, y: discriminator_fn(x, y)[0], # conform to non-InfoGAN API
structured_generator_inputs,
predicted_distributions)
def _validate_acgan_discriminator_outputs(discriminator_output):
try:
a, b = discriminator_output
except (TypeError, ValueError):
raise TypeError(
'A discriminator function for ACGAN must output a tuple '
'consisting of (discrimination logits, classification logits).')
return a, b
def acgan_model(
# Lambdas defining models.
generator_fn,
discriminator_fn,
# Real data and conditioning.
real_data,
generator_inputs,
one_hot_labels,
# Optional scopes.
generator_scope='Generator',
discriminator_scope='Discriminator',
check_shapes=True):
"""Returns an ACGANModel contains all the pieces needed for ACGAN training.
The `acgan_model` is the same as the `gan_model` with the only difference
being that the discriminator additionally outputs logits to classify the input
(real or generated).
Therefore, an explicit field holding one_hot_labels is necessary, as well as a
discriminator_fn that outputs a 2-tuple holding the logits for real/fake and
classification.
See https://arxiv.org/abs/1610.09585 for more details.
Args:
generator_fn: A python lambda that takes `generator_inputs` as inputs and
returns the outputs of the GAN generator.
discriminator_fn: A python lambda that takes `real_data`/`generated data`
and `generator_inputs`. Outputs a tuple consisting of two Tensors:
(1) real/fake logits in the range [-inf, inf]
(2) classification logits in the range [-inf, inf]
real_data: A Tensor representing the real data.
generator_inputs: A Tensor or list of Tensors to the generator. In the
vanilla GAN case, this might be a single noise Tensor. In the conditional
GAN case, this might be the generator's conditioning.
one_hot_labels: A Tensor holding one-hot-labels for the batch. Needed by
acgan_loss.
generator_scope: Optional generator variable scope. Useful if you want to
reuse a subgraph that has already been created.
discriminator_scope: Optional discriminator variable scope. Useful if you
want to reuse a subgraph that has already been created.
check_shapes: If `True`, check that generator produces Tensors that are the
same shape as real data. Otherwise, skip this check.
Returns:
A ACGANModel namedtuple.
Raises:
ValueError: If the generator outputs a Tensor that isn't the same shape as
`real_data`.
TypeError: If the discriminator does not output a tuple consisting of
(discrimination logits, classification logits).
"""
# Create models
with variable_scope.variable_scope(generator_scope) as gen_scope:
generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
generated_data = generator_fn(generator_inputs)
with variable_scope.variable_scope(discriminator_scope) as dis_scope:
(discriminator_gen_outputs, discriminator_gen_classification_logits
) = _validate_acgan_discriminator_outputs(
discriminator_fn(generated_data, generator_inputs))
with variable_scope.variable_scope(dis_scope, reuse=True):
real_data = ops.convert_to_tensor(real_data)
(discriminator_real_outputs, discriminator_real_classification_logits
) = _validate_acgan_discriminator_outputs(
discriminator_fn(real_data, generator_inputs))
if check_shapes:
if not generated_data.shape.is_compatible_with(real_data.shape):
raise ValueError(
'Generator output shape (%s) must be the same shape as real data '
'(%s).' % (generated_data.shape, real_data.shape))
# Get model-specific variables.
generator_variables = variables_lib.get_trainable_variables(gen_scope)
discriminator_variables = variables_lib.get_trainable_variables(
dis_scope)
return namedtuples.ACGANModel(
generator_inputs, generated_data, generator_variables, gen_scope,
generator_fn, real_data, discriminator_real_outputs,
discriminator_gen_outputs, discriminator_variables, dis_scope,
discriminator_fn, one_hot_labels,
discriminator_real_classification_logits,
discriminator_gen_classification_logits)
def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
if isinstance(aux_loss_weight, ops.Tensor):
aux_loss_weight.shape.assert_is_compatible_with([])
with ops.control_dependencies(
[check_ops.assert_greater_equal(aux_loss_weight, 0.0)]):
aux_loss_weight = array_ops.identity(aux_loss_weight)
elif aux_loss_weight is not None and aux_loss_weight < 0:
raise ValueError('`%s` must be greater than 0. Instead, was %s' %
(name, aux_loss_weight))
return aux_loss_weight
def _use_aux_loss(aux_loss_weight):
if aux_loss_weight is not None:
if not isinstance(aux_loss_weight, ops.Tensor):
return aux_loss_weight > 0
else:
return True
else:
return False
def gan_loss(
# GANModel.
model,
# Loss functions.
generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
# Auxiliary losses.
gradient_penalty_weight=None,
gradient_penalty_epsilon=1e-10,
mutual_information_penalty_weight=None,
aux_cond_generator_weight=None,
aux_cond_discriminator_weight=None,
# Options.
add_summaries=True):
"""Returns losses necessary to train generator and discriminator.
Args:
model: A GANModel tuple.
generator_loss_fn: The loss function on the generator. Takes a GANModel
tuple.
discriminator_loss_fn: The loss function on the discriminator. Takes a
GANModel tuple.
gradient_penalty_weight: If not `None`, must be a non-negative Python number
or Tensor indicating how much to weight the gradient penalty. See
https://arxiv.org/pdf/1704.00028.pdf for more details.
gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
small positive value used by the gradient penalty function for numerical
stability. Note some applications will need to increase this value to
avoid NaNs.
mutual_information_penalty_weight: If not `None`, must be a non-negative
Python number or Tensor indicating how much to weight the mutual
information penalty. See https://arxiv.org/abs/1606.03657 for more
details.
aux_cond_generator_weight: If not None: add a classification loss as in
https://arxiv.org/abs/1610.09585
aux_cond_discriminator_weight: If not None: add a classification loss as in
https://arxiv.org/abs/1610.09585
add_summaries: Whether or not to add summaries for the losses.
Returns:
A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
regularization losses.
Raises:
ValueError: If any of the auxiliary loss weights is provided and negative.
ValueError: If `mutual_information_penalty_weight` is provided, but the
`model` isn't an `InfoGANModel`.
"""
# Validate arguments.
gradient_penalty_weight = _validate_aux_loss_weight(gradient_penalty_weight,
'gradient_penalty_weight')
mutual_information_penalty_weight = _validate_aux_loss_weight(
mutual_information_penalty_weight, 'infogan_weight')
aux_cond_generator_weight = _validate_aux_loss_weight(
aux_cond_generator_weight, 'aux_cond_generator_weight')
aux_cond_discriminator_weight = _validate_aux_loss_weight(
aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')
# Verify configuration for mutual information penalty
if (_use_aux_loss(mutual_information_penalty_weight) and
not isinstance(model, namedtuples.InfoGANModel)):
raise ValueError(
'When `mutual_information_penalty_weight` is provided, `model` must be '
'an `InfoGANModel`. Instead, was %s.' % type(model))
# Verify configuration for mutual auxiliary condition loss (ACGAN).
if ((_use_aux_loss(aux_cond_generator_weight) or
_use_aux_loss(aux_cond_discriminator_weight)) and
not isinstance(model, namedtuples.ACGANModel)):
raise ValueError(
'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
type(model))
# Create standard losses.
gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
dis_loss = discriminator_loss_fn(model, add_summaries=add_summaries)
# Add optional extra losses.
if _use_aux_loss(gradient_penalty_weight):
gp_loss = tfgan_losses.wasserstein_gradient_penalty(
model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries)
dis_loss += gradient_penalty_weight * gp_loss
if _use_aux_loss(mutual_information_penalty_weight):
info_loss = tfgan_losses.mutual_information_penalty(
model, add_summaries=add_summaries)
dis_loss += mutual_information_penalty_weight * info_loss
gen_loss += mutual_information_penalty_weight * info_loss
if _use_aux_loss(aux_cond_generator_weight):
ac_gen_loss = tfgan_losses.acgan_generator_loss(
model, add_summaries=add_summaries)
gen_loss += aux_cond_generator_weight * ac_gen_loss
if _use_aux_loss(aux_cond_discriminator_weight):
ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
model, add_summaries=add_summaries)
dis_loss += aux_cond_discriminator_weight * ac_disc_loss
# Gathers auxilliary losses.
if model.generator_scope:
gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
else:
gen_reg_loss = 0
if model.discriminator_scope:
dis_reg_loss = losses.get_regularization_loss(
model.discriminator_scope.name)
else:
dis_reg_loss = 0
return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)
def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
"""Gets generator and discriminator update ops.
Args:
kwargs: A dictionary of kwargs to be passed to `create_train_op`.
`update_ops` is removed, if present.
gen_scope: A scope for the generator.
dis_scope: A scope for the discriminator.
check_for_unused_ops: A Python bool. If `True`, throw Exception if there are
unused update ops.
Returns:
A 2-tuple of (generator update ops, discriminator train ops).
Raises:
ValueError: If there are update ops outside of the generator or
discriminator scopes.
"""
if 'update_ops' in kwargs:
update_ops = set(kwargs['update_ops'])
del kwargs['update_ops']
else:
update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
all_gen_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, gen_scope))
all_dis_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, dis_scope))
if check_for_unused_ops:
unused_ops = update_ops - all_gen_ops - all_dis_ops
if unused_ops:
raise ValueError('There are unused update ops: %s' % unused_ops)
gen_update_ops = list(all_gen_ops & update_ops)
dis_update_ops = list(all_dis_ops & update_ops)
return gen_update_ops, dis_update_ops
def gan_train_ops(
model, # GANModel
loss, # GANLoss
generator_optimizer,
discriminator_optimizer,
# Optional check flags.
check_for_unused_update_ops=True,
# Optional args to pass directly to the `create_train_op`.
**kwargs):
"""Returns GAN train ops.
The highest-level call in TFGAN. It is composed of functions that can also
be called, should a user require more control over some part of the GAN
training process.
Args:
model: A GANModel.
loss: A GANLoss.
generator_optimizer: The optimizer for generator updates.
discriminator_optimizer: The optimizer for the discriminator updates.
check_for_unused_update_ops: If `True`, throws an exception if there are
update ops outside of the generator or discriminator scopes.
**kwargs: Keyword args to pass directly to
`training.create_train_op` for both the generator and
discriminator train op.
Returns:
A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that can
be used to train a generator/discriminator pair.
"""
# Create global step increment op.
global_step = training_util.get_or_create_global_step()
global_step_inc = global_step.assign_add(1)
# Get generator and discriminator update ops. We split them so that update
# ops aren't accidentally run multiple times. For now, throw an error if
# there are update ops that aren't associated with either the generator or
# the discriminator. Might modify the `kwargs` dictionary.
gen_update_ops, dis_update_ops = _get_update_ops(
kwargs, model.generator_scope.name, model.discriminator_scope.name,
check_for_unused_update_ops)
generator_global_step = None
if isinstance(generator_optimizer,
sync_replicas_optimizer.SyncReplicasOptimizer):
# TODO(joelshor): Figure out a way to get this work without including the
# dummy global step in the checkpoint.
# WARNING: Making this variable a local variable causes sync replicas to
# hang forever.
generator_global_step = variable_scope.get_variable(
'dummy_global_step_generator',
shape=[],
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer(),
trainable=False,
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
gen_update_ops += [generator_global_step.assign(global_step)]
with ops.name_scope('generator_train'):
gen_train_op = training.create_train_op(
total_loss=loss.generator_loss,
optimizer=generator_optimizer,
variables_to_train=model.generator_variables,
global_step=generator_global_step,
update_ops=gen_update_ops,
**kwargs)
discriminator_global_step = None
if isinstance(discriminator_optimizer,
sync_replicas_optimizer.SyncReplicasOptimizer):
# See comment above `generator_global_step`.
discriminator_global_step = variable_scope.get_variable(
'dummy_global_step_discriminator',
shape=[],
dtype=dtypes.int64,
initializer=init_ops.zeros_initializer(),
trainable=False,
collections=[ops.GraphKeys.GLOBAL_VARIABLES])
dis_update_ops += [discriminator_global_step.assign(global_step)]
with ops.name_scope('discriminator_train'):
disc_train_op = training.create_train_op(
total_loss=loss.discriminator_loss,
optimizer=discriminator_optimizer,
variables_to_train=model.discriminator_variables,
global_step=discriminator_global_step,
update_ops=dis_update_ops,
**kwargs)
return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)
# TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive
# Image Compression` (https://arxiv.org/abs/1705.05823)
class RunTrainOpsHook(session_run_hook.SessionRunHook):
"""A hook to run train ops a fixed number of times."""
def __init__(self, train_ops, train_steps):
"""Run train ops a certain number of times.
Args:
train_ops: A train op or iterable of train ops to run.
train_steps: The number of times to run the op(s).
"""
if not isinstance(train_ops, (list, tuple)):
train_ops = [train_ops]
self._train_ops = train_ops
self._train_steps = train_steps
def before_run(self, run_context):
for _ in range(self._train_steps):
run_context.session.run(self._train_ops)
def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
"""Returns a hooks function for sequential GAN training.
Args:
train_steps: A `GANTrainSteps` tuple that determines how many generator
and discriminator training steps to take.
Returns:
A function that takes a GANTrainOps tuple and returns a list of hooks.
"""
def get_hooks(train_ops):
generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
train_steps.generator_train_steps)
discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
train_steps.discriminator_train_steps)
return [generator_hook, discriminator_hook]
return get_hooks
def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
"""Returns a hooks function for sequential GAN training.
When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON
ALL OPTIMIZERS TO AVOID RACE CONDITIONS.
The order of steps taken is:
1) Combined generator and discriminator steps
2) Generator only steps, if any remain
3) Discriminator only steps, if any remain
**NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates
for the generator and discriminator simultaneously whenever possible. This
reduces the number of `tf.Session` calls, and can also change the training
semantics.
To illustrate the difference look at the following example:
`train_steps=namedtuples.GANTrainSteps(3, 5)` will cause
`get_sequential_train_hooks` to make 8 session calls:
1) 3 generator steps
2) 5 discriminator steps
In contrast, `get_joint_train_steps` will make 5 session calls:
1) 3 generator + discriminator steps
2) 2 discriminator steps
Args:
train_steps: A `GANTrainSteps` tuple that determines how many generator
and discriminator training steps to take.
Returns:
A function that takes a GANTrainOps tuple and returns a list of hooks.
"""
g_steps = train_steps.generator_train_steps
d_steps = train_steps.discriminator_train_steps
# Get the number of each type of step that should be run.
num_d_and_g_steps = min(g_steps, d_steps)
num_g_steps = g_steps - num_d_and_g_steps
num_d_steps = d_steps - num_d_and_g_steps
def get_hooks(train_ops):
g_op = train_ops.generator_train_op
d_op = train_ops.discriminator_train_op
joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps)
g_hook = RunTrainOpsHook(g_op, num_g_steps)
d_hook = RunTrainOpsHook(d_op, num_d_steps)
return [joint_hook, g_hook, d_hook]
return get_hooks
# TODO(joelshor): This function currently returns the global step. Find a
# good way for it to return the generator, discriminator, and final losses.
def gan_train(
train_ops,
logdir,
get_hooks_fn=get_sequential_train_hooks(),
master='',
is_chief=True,
scaffold=None,
hooks=None,
chief_only_hooks=None,
save_checkpoint_secs=600,
save_summaries_steps=100,
config=None):
"""A wrapper around `contrib.training.train` that uses GAN hooks.
Args:
train_ops: A GANTrainOps named tuple.
logdir: The directory where the graph and checkpoints are saved.
get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
of hooks.
master: The URL of the master.
is_chief: Specifies whether or not the training is being run by the primary
replica during replica training.
scaffold: An tf.train.Scaffold instance.
hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
training loop.
chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
inside the training loop for the chief trainer only.
save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
using a default checkpoint saver. If `save_checkpoint_secs` is set to
`None`, then the default checkpoint saver isn't used.
save_summaries_steps: The frequency, in number of global steps, that the
summaries are written to disk using a default summary saver. If
`save_summaries_steps` is set to `None`, then the default summary saver
isn't used.
config: An instance of `tf.ConfigProto`.
Returns:
Output of the call to `training.train`.
"""
new_hooks = get_hooks_fn(train_ops)
if hooks is not None:
hooks = list(hooks) + list(new_hooks)
else:
hooks = new_hooks
return training.train(
train_ops.global_step_inc_op,
logdir,
master=master,
is_chief=is_chief,
scaffold=scaffold,
hooks=hooks,
chief_only_hooks=chief_only_hooks,
save_checkpoint_secs=save_checkpoint_secs,
save_summaries_steps=save_summaries_steps,
config=config)
def get_sequential_train_steps(
train_steps=namedtuples.GANTrainSteps(1, 1)):
"""Returns a thin wrapper around slim.learning.train_step, for GANs.
This function is to provide support for the Supervisor. For new code, please
use `MonitoredSession` and `get_sequential_train_hooks`.
Args:
train_steps: A `GANTrainSteps` tuple that determines how many generator
and discriminator training steps to take.
Returns:
A function that can be used for `train_step_fn` for GANs.
"""
def sequential_train_steps(sess, train_ops, global_step, train_step_kwargs):
"""A thin wrapper around slim.learning.train_step, for GANs.
Args:
sess: A Tensorflow session.
train_ops: A GANTrainOps tuple of train ops to run.
global_step: The global step.
train_step_kwargs: Dictionary controlling `train_step` behavior.
Returns:
A scalar final loss and a bool whether or not the train loop should stop.
"""
# Only run `should_stop` at the end, if required. Make a local copy of
# `train_step_kwargs`, if necessary, so as not to modify the caller's
# dictionary.
should_stop_op, train_kwargs = None, train_step_kwargs
if 'should_stop' in train_step_kwargs:
should_stop_op = train_step_kwargs['should_stop']
train_kwargs = train_step_kwargs.copy()
del train_kwargs['should_stop']
# Run generator training steps.
gen_loss = 0
for _ in range(train_steps.generator_train_steps):
cur_gen_loss, _ = slim_learning.train_step(
sess, train_ops.generator_train_op, global_step, train_kwargs)
gen_loss += cur_gen_loss
# Run discriminator training steps.
dis_loss = 0
for _ in range(train_steps.discriminator_train_steps):
cur_dis_loss, _ = slim_learning.train_step(
sess, train_ops.discriminator_train_op, global_step, train_kwargs)
dis_loss += cur_dis_loss
sess.run(train_ops.global_step_inc_op)
# Run the `should_stop` op after the global step has been incremented, so
# that the `should_stop` aligns with the proper `global_step` count.
if should_stop_op is not None:
should_stop = sess.run(should_stop_op)
else:
should_stop = False
return gen_loss + dis_loss, should_stop
return sequential_train_steps

View File

@ -0,0 +1,745 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for gan.python.train."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python import train
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.platform import test
from tensorflow.python.training import basic_session_run_hooks
from tensorflow.python.training import coordinator
from tensorflow.python.training import gradient_descent
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util
def generator_model(inputs):
return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs
class Generator(object):
def __call__(self, inputs):
return generator_model(inputs)
def infogan_generator_model(inputs):
return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs[0]
class InfoGANGenerator(object):
def __call__(self, inputs):
return infogan_generator_model(inputs)
def discriminator_model(inputs, _):
return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
class Discriminator(object):
def __call__(self, inputs, _):
return discriminator_model(inputs, _)
def infogan_discriminator_model(inputs, _):
return (variable_scope.get_variable('dummy_d', initializer=2.0) * inputs,
[categorical.Categorical([1.0])])
class InfoGANDiscriminator(object):
def __call__(self, inputs, _):
return infogan_discriminator_model(inputs, _)
def acgan_discriminator_model(inputs, _, num_classes=10):
return (discriminator_model(inputs, _), array_ops.one_hot(
# TODO(haeusser): infer batch size from input
random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32),
num_classes))
class ACGANDiscriminator(object):
def __call__(self, inputs, _, num_classes=10):
return (discriminator_model(inputs, _), array_ops.one_hot(
# TODO(haeusser): infer batch size from input
random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32),
num_classes))
def get_gan_model():
# TODO(joelshor): Find a better way of creating a variable scope.
with variable_scope.variable_scope('generator') as gen_scope:
pass
with variable_scope.variable_scope('discriminator') as dis_scope:
pass
return namedtuples.GANModel(
generator_inputs=None,
generated_data=None,
generator_variables=None,
generator_scope=gen_scope,
generator_fn=generator_model,
real_data=array_ops.ones([1, 2, 3]),
discriminator_real_outputs=array_ops.ones([1, 2, 3]),
discriminator_gen_outputs=array_ops.ones([1, 2, 3]),
discriminator_variables=None,
discriminator_scope=dis_scope,
discriminator_fn=discriminator_model)
def get_callable_gan_model():
ganmodel = get_gan_model()
return ganmodel._replace(
generator_fn=Generator(),
discriminator_fn=Discriminator())
def create_gan_model():
return train.gan_model(
generator_model,
discriminator_model,
real_data=array_ops.zeros([1, 2]),
generator_inputs=random_ops.random_normal([1, 2]))
def create_callable_gan_model():
return train.gan_model(
Generator(),
Discriminator(),
real_data=array_ops.zeros([1, 2]),
generator_inputs=random_ops.random_normal([1, 2]))
def get_infogan_model():
return namedtuples.InfoGANModel(
*get_gan_model(),
structured_generator_inputs=[constant_op.constant(0)],
predicted_distributions=[categorical.Categorical([1.0])])
def get_callable_infogan_model():
return namedtuples.InfoGANModel(
*get_callable_gan_model(),
structured_generator_inputs=[constant_op.constant(0)],
predicted_distributions=[categorical.Categorical([1.0])])
def create_infogan_model():
return train.infogan_model(
infogan_generator_model,
infogan_discriminator_model,
real_data=array_ops.zeros([1, 2]),
unstructured_generator_inputs=[],
structured_generator_inputs=[random_ops.random_normal([1, 2])])
def create_callable_infogan_model():
return train.infogan_model(
InfoGANGenerator(),
InfoGANDiscriminator(),
real_data=array_ops.zeros([1, 2]),
unstructured_generator_inputs=[],
structured_generator_inputs=[random_ops.random_normal([1, 2])])
def get_acgan_model():
return namedtuples.ACGANModel(
*get_gan_model(),
one_hot_labels=array_ops.one_hot([0, 1, 2], 10),
discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10),
discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10))
def get_callable_acgan_model():
return namedtuples.ACGANModel(
*get_callable_gan_model(),
one_hot_labels=array_ops.one_hot([0, 1, 2], 10),
discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10),
discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10))
def create_acgan_model():
return train.acgan_model(
generator_model,
acgan_discriminator_model,
real_data=array_ops.zeros([1, 2]),
generator_inputs=random_ops.random_normal([1, 2]),
one_hot_labels=array_ops.one_hot([0, 1, 2], 10))
def create_callable_acgan_model():
return train.acgan_model(
Generator(),
ACGANDiscriminator(),
real_data=array_ops.zeros([1, 2]),
generator_inputs=random_ops.random_normal([1, 2]),
one_hot_labels=array_ops.one_hot([0, 1, 2], 10))
def get_sync_optimizer():
return sync_replicas_optimizer.SyncReplicasOptimizer(
gradient_descent.GradientDescentOptimizer(learning_rate=1.0),
replicas_to_aggregate=1)
class GANModelTest(test.TestCase):
"""Tests for `gan_model`."""
def _test_output_type_helper(self, create_fn, tuple_type):
self.assertTrue(isinstance(create_fn(), tuple_type))
def test_output_type_gan(self):
self._test_output_type_helper(get_gan_model, namedtuples.GANModel)
def test_output_type_callable_gan(self):
self._test_output_type_helper(get_callable_gan_model, namedtuples.GANModel)
def test_output_type_infogan(self):
self._test_output_type_helper(get_infogan_model, namedtuples.InfoGANModel)
def test_output_type_callable_infogan(self):
self._test_output_type_helper(
get_callable_infogan_model, namedtuples.InfoGANModel)
def test_output_type_acgan(self):
self._test_output_type_helper(get_acgan_model, namedtuples.ACGANModel)
def test_output_type_callable_acgan(self):
self._test_output_type_helper(
get_callable_acgan_model, namedtuples.ACGANModel)
def test_no_shape_check(self):
def dummy_generator_model(_):
return (None, None)
def dummy_discriminator_model(data, conditioning): # pylint: disable=unused-argument
return 1
with self.assertRaisesRegexp(AttributeError, 'object has no attribute'):
train.gan_model(
dummy_generator_model,
dummy_discriminator_model,
real_data=array_ops.zeros([1, 2]),
generator_inputs=array_ops.zeros([1]),
check_shapes=True)
train.gan_model(
dummy_generator_model,
dummy_discriminator_model,
real_data=array_ops.zeros([1, 2]),
generator_inputs=array_ops.zeros([1]),
check_shapes=False)
class GANLossTest(test.TestCase):
"""Tests for `gan_loss`."""
# Test output type.
def _test_output_type_helper(self, get_gan_model_fn):
loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
self.assertTrue(isinstance(loss, namedtuples.GANLoss))
self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
def test_output_type_gan(self):
self._test_output_type_helper(get_gan_model)
def test_output_type_callable_gan(self):
self._test_output_type_helper(get_callable_gan_model)
def test_output_type_infogan(self):
self._test_output_type_helper(get_infogan_model)
def test_output_type_callable_infogan(self):
self._test_output_type_helper(get_callable_infogan_model)
def test_output_type_acgan(self):
self._test_output_type_helper(get_acgan_model)
def test_output_type_callable_acgan(self):
self._test_output_type_helper(get_callable_acgan_model)
# Test gradient penalty option.
def _test_grad_penalty_helper(self, create_gan_model_fn):
model = create_gan_model_fn()
loss = train.gan_loss(model)
loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0)
self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss))
# Check values.
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
loss_gen_np, loss_gen_gp_np = sess.run(
[loss.generator_loss, loss_gp.generator_loss])
loss_dis_np, loss_dis_gp_np = sess.run(
[loss.discriminator_loss, loss_gp.discriminator_loss])
self.assertEqual(loss_gen_np, loss_gen_gp_np)
self.assertTrue(loss_dis_np < loss_dis_gp_np)
def test_grad_penalty_gan(self):
self._test_grad_penalty_helper(create_gan_model)
def test_grad_penalty_callable_gan(self):
self._test_grad_penalty_helper(create_callable_gan_model)
def test_grad_penalty_infogan(self):
self._test_grad_penalty_helper(create_infogan_model)
def test_grad_penalty_callable_infogan(self):
self._test_grad_penalty_helper(create_callable_infogan_model)
def test_grad_penalty_acgan(self):
self._test_grad_penalty_helper(create_acgan_model)
def test_grad_penalty_callable_acgan(self):
self._test_grad_penalty_helper(create_callable_acgan_model)
# Test mutual information penalty option.
def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
train.gan_loss(create_gan_model_fn(),
mutual_information_penalty_weight=constant_op.constant(1.0))
def test_mutual_info_penalty_infogan(self):
self._test_mutual_info_penalty_helper(get_infogan_model)
def test_mutual_info_penalty_callable_infogan(self):
self._test_mutual_info_penalty_helper(get_callable_infogan_model)
# Test regularization loss.
def _test_regularization_helper(self, get_gan_model_fn):
# Evaluate losses without regularization.
no_reg_loss = train.gan_loss(get_gan_model_fn())
with self.test_session(use_gpu=True):
no_reg_loss_gen_np = no_reg_loss.generator_loss.eval()
no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()
with ops.name_scope(get_gan_model_fn().generator_scope.name):
ops.add_to_collection(
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
ops.add_to_collection(
ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))
# Check that losses now include the correct regularization values.
reg_loss = train.gan_loss(get_gan_model_fn())
with self.test_session(use_gpu=True):
reg_loss_gen_np = reg_loss.generator_loss.eval()
reg_loss_dis_np = reg_loss.discriminator_loss.eval()
self.assertTrue(3.0, reg_loss_gen_np - no_reg_loss_gen_np)
self.assertTrue(3.0, reg_loss_dis_np - no_reg_loss_dis_np)
def test_regularization_gan(self):
self._test_regularization_helper(get_gan_model)
def test_regularization_callable_gan(self):
self._test_regularization_helper(get_callable_gan_model)
def test_regularization_infogan(self):
self._test_regularization_helper(get_infogan_model)
def test_regularization_callable_infogan(self):
self._test_regularization_helper(get_callable_infogan_model)
def test_regularization_acgan(self):
self._test_regularization_helper(get_acgan_model)
def test_regularization_callable_acgan(self):
self._test_regularization_helper(get_callable_acgan_model)
# Test that ACGan models work.
def _test_acgan_helper(self, create_gan_model_fn):
model = create_gan_model_fn()
loss = train.gan_loss(model)
loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0)
loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
self.assertTrue(isinstance(loss, namedtuples.GANLoss))
self.assertTrue(isinstance(loss_ac_gen, namedtuples.GANLoss))
self.assertTrue(isinstance(loss_ac_dis, namedtuples.GANLoss))
# Check values.
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run(
[loss.generator_loss,
loss_ac_gen.generator_loss,
loss_ac_dis.generator_loss])
loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run(
[loss.discriminator_loss,
loss_ac_gen.discriminator_loss,
loss_ac_dis.discriminator_loss])
self.assertTrue(loss_gen_np < loss_dis_np)
self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
self.assertTrue(np.isscalar(loss_ac_dis_gen_np))
self.assertTrue(np.isscalar(loss_ac_gen_dis_np))
self.assertTrue(np.isscalar(loss_ac_dis_dis_np))
def test_acgan(self):
self._test_acgan_helper(create_acgan_model)
def test_callable_acgan(self):
self._test_acgan_helper(create_callable_acgan_model)
def test_doesnt_crash_when_in_nested_scope(self):
with variable_scope.variable_scope('outer_scope'):
gan_model = train.gan_model(
generator_model,
discriminator_model,
real_data=array_ops.zeros([1, 2]),
generator_inputs=random_ops.random_normal([1, 2]))
# This should work inside a scope.
train.gan_loss(gan_model, gradient_penalty_weight=1.0)
# This should also work outside a scope.
train.gan_loss(gan_model, gradient_penalty_weight=1.0)
class GANTrainOpsTest(test.TestCase):
"""Tests for `gan_train_ops`."""
def _test_output_type_helper(self, create_gan_model_fn):
model = create_gan_model_fn()
loss = train.gan_loss(model)
g_opt = gradient_descent.GradientDescentOptimizer(1.0)
d_opt = gradient_descent.GradientDescentOptimizer(1.0)
train_ops = train.gan_train_ops(
model,
loss,
g_opt,
d_opt,
summarize_gradients=True,
colocate_gradients_with_ops=True)
self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
def test_output_type_gan(self):
self._test_output_type_helper(create_gan_model)
def test_output_type_callable_gan(self):
self._test_output_type_helper(create_callable_gan_model)
def test_output_type_infogan(self):
self._test_output_type_helper(create_infogan_model)
def test_output_type_callable_infogan(self):
self._test_output_type_helper(create_callable_infogan_model)
def test_output_type_acgan(self):
self._test_output_type_helper(create_acgan_model)
def test_output_type_callable_acgan(self):
self._test_output_type_helper(create_callable_acgan_model)
# TODO(joelshor): Add a test to check that custom update op is run.
def _test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
model = create_gan_model_fn()
loss = train.gan_loss(model)
# Add generator and discriminator update ops.
with variable_scope.variable_scope(model.generator_scope):
gen_update_count = variable_scope.get_variable('gen_count', initializer=0)
gen_update_op = gen_update_count.assign_add(1)
ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
with variable_scope.variable_scope(model.discriminator_scope):
dis_update_count = variable_scope.get_variable('dis_count', initializer=0)
dis_update_op = dis_update_count.assign_add(1)
ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)
# Add an update op outside the generator and discriminator scopes.
if provide_update_ops:
kwargs = {'update_ops':
[constant_op.constant(1.0), gen_update_op, dis_update_op]}
else:
ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0))
kwargs = {}
g_opt = gradient_descent.GradientDescentOptimizer(1.0)
d_opt = gradient_descent.GradientDescentOptimizer(1.0)
with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
train.gan_train_ops(model, loss, g_opt, d_opt,
check_for_unused_update_ops=True, **kwargs)
train_ops = train.gan_train_ops(
model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)
with self.test_session(use_gpu=True) as sess:
sess.run(variables.global_variables_initializer())
self.assertEqual(0, gen_update_count.eval())
self.assertEqual(0, dis_update_count.eval())
train_ops.generator_train_op.eval()
self.assertEqual(1, gen_update_count.eval())
self.assertEqual(0, dis_update_count.eval())
train_ops.discriminator_train_op.eval()
self.assertEqual(1, gen_update_count.eval())
self.assertEqual(1, dis_update_count.eval())
def test_unused_update_ops_gan(self):
self._test_unused_update_ops(create_gan_model, False)
def test_unused_update_ops_gan_provideupdates(self):
self._test_unused_update_ops(create_gan_model, True)
def test_unused_update_ops_callable_gan(self):
self._test_unused_update_ops(create_callable_gan_model, False)
def test_unused_update_ops_callable_gan_provideupdates(self):
self._test_unused_update_ops(create_callable_gan_model, True)
def test_unused_update_ops_infogan(self):
self._test_unused_update_ops(create_infogan_model, False)
def test_unused_update_ops_infogan_provideupdates(self):
self._test_unused_update_ops(create_infogan_model, True)
def test_unused_update_ops_callable_infogan(self):
self._test_unused_update_ops(create_callable_infogan_model, False)
def test_unused_update_ops_callable_infogan_provideupdates(self):
self._test_unused_update_ops(create_callable_infogan_model, True)
def test_unused_update_ops_acgan(self):
self._test_unused_update_ops(create_acgan_model, False)
def test_unused_update_ops_acgan_provideupdates(self):
self._test_unused_update_ops(create_acgan_model, True)
def test_unused_update_ops_callable_acgan(self):
self._test_unused_update_ops(create_callable_acgan_model, False)
def test_unused_update_ops_callable_acgan_provideupdates(self):
self._test_unused_update_ops(create_callable_acgan_model, True)
def _test_sync_replicas_helper(self, create_gan_model_fn):
model = create_gan_model_fn()
loss = train.gan_loss(model)
num_trainable_vars = len(variables_lib.get_trainable_variables())
g_opt = get_sync_optimizer()
d_opt = get_sync_optimizer()
train_ops = train.gan_train_ops(
model,
loss,
generator_optimizer=g_opt,
discriminator_optimizer=d_opt)
self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
# No new trainable variables should have been added.
self.assertEqual(num_trainable_vars,
len(variables_lib.get_trainable_variables()))
g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)
# Check that update op is run properly.
global_step = training_util.get_or_create_global_step()
with self.test_session(use_gpu=True) as sess:
variables.global_variables_initializer().run()
variables.local_variables_initializer().run()
g_opt.chief_init_op.run()
d_opt.chief_init_op.run()
gstep_before = global_step.eval()
# Start required queue runner for SyncReplicasOptimizer.
coord = coordinator.Coordinator()
g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)
g_sync_init_op.run()
d_sync_init_op.run()
train_ops.generator_train_op.eval()
# Check that global step wasn't incremented.
self.assertEqual(gstep_before, global_step.eval())
train_ops.discriminator_train_op.eval()
# Check that global step wasn't incremented.
self.assertEqual(gstep_before, global_step.eval())
coord.request_stop()
coord.join(g_threads + d_threads)
def test_sync_replicas_gan(self):
self._test_sync_replicas_helper(create_gan_model)
def test_sync_replicas_callable_gan(self):
self._test_sync_replicas_helper(create_callable_gan_model)
def test_sync_replicas_infogan(self):
self._test_sync_replicas_helper(create_infogan_model)
def test_sync_replicas_callable_infogan(self):
self._test_sync_replicas_helper(create_callable_infogan_model)
def test_sync_replicas_acgan(self):
self._test_sync_replicas_helper(create_acgan_model)
def test_sync_replicas_callable_acgan(self):
self._test_sync_replicas_helper(create_callable_acgan_model)
class GANTrainTest(test.TestCase):
"""Tests for `gan_train`."""
def _gan_train_ops(self, generator_add, discriminator_add):
step = training_util.create_global_step()
# Increment the global count every time a train op is run so we can count
# the number of times they're run.
# NOTE: `use_locking=True` is required to avoid race conditions with
# joint training.
train_ops = namedtuples.GANTrainOps(
generator_train_op=step.assign_add(generator_add, use_locking=True),
discriminator_train_op=step.assign_add(discriminator_add,
use_locking=True),
global_step_inc_op=step.assign_add(1))
return train_ops
def _test_run_helper(self, create_gan_model_fn):
random_seed.set_random_seed(1234)
model = create_gan_model_fn()
loss = train.gan_loss(model)
g_opt = gradient_descent.GradientDescentOptimizer(1.0)
d_opt = gradient_descent.GradientDescentOptimizer(1.0)
train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)
final_step = train.gan_train(
train_ops,
logdir='',
hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
self.assertTrue(np.isscalar(final_step))
self.assertEqual(2, final_step)
def test_run_gan(self):
self._test_run_helper(create_gan_model)
def test_run_callable_gan(self):
self._test_run_helper(create_callable_gan_model)
def test_run_infogan(self):
self._test_run_helper(create_infogan_model)
def test_run_callable_infogan(self):
self._test_run_helper(create_callable_infogan_model)
def test_run_acgan(self):
self._test_run_helper(create_acgan_model)
def test_run_callable_acgan(self):
self._test_run_helper(create_callable_acgan_model)
# Test multiple train steps.
def _test_multiple_steps_helper(self, get_hooks_fn_fn):
train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
train_steps = namedtuples.GANTrainSteps(
generator_train_steps=3,
discriminator_train_steps=4)
final_step = train.gan_train(
train_ops,
get_hooks_fn=get_hooks_fn_fn(train_steps),
logdir='',
hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])
self.assertTrue(np.isscalar(final_step))
self.assertEqual(1 + 3 * 10 + 4 * 100, final_step)
def test_multiple_steps_seq_train_steps(self):
self._test_multiple_steps_helper(train.get_sequential_train_hooks)
def test_multiple_steps_efficient_seq_train_steps(self):
self._test_multiple_steps_helper(train.get_joint_train_hooks)
def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
step = training_util.create_global_step()
train_ops = namedtuples.GANTrainOps(
generator_train_op=constant_op.constant(3.0),
discriminator_train_op=constant_op.constant(2.0),
global_step_inc_op=step.assign_add(1))
train_steps = namedtuples.GANTrainSteps(
generator_train_steps=3,
discriminator_train_steps=4)
final_loss = slim_learning.train(
train_op=train_ops,
logdir='',
global_step=step,
number_of_steps=1,
train_step_fn=train.get_sequential_train_steps(train_steps))
self.assertTrue(np.isscalar(final_loss))
self.assertEqual(17.0, final_loss)
class PatchGANTest(test.TestCase):
"""Tests that functions work on PatchGAN style output."""
def _test_patchgan_helper(self, create_gan_model_fn):
"""Ensure that patch-based discriminators work end-to-end."""
random_seed.set_random_seed(1234)
model = create_gan_model_fn()
loss = train.gan_loss(model)
g_opt = gradient_descent.GradientDescentOptimizer(1.0)
d_opt = gradient_descent.GradientDescentOptimizer(1.0)
train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)
final_step = train.gan_train(
train_ops,
logdir='',
hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
self.assertTrue(np.isscalar(final_step))
self.assertEqual(2, final_step)
def test_patchgan_gan(self):
self._test_patchgan_helper(create_gan_model)
def test_patchgan_callable_gan(self):
self._test_patchgan_helper(create_callable_gan_model)
def test_patchgan_infogan(self):
self._test_patchgan_helper(create_infogan_model)
def test_patchgan_callable_infogan(self):
self._test_patchgan_helper(create_callable_infogan_model)
def test_patchgan_acgan(self):
self._test_patchgan_helper(create_acgan_model)
def test_patchgan_callable_acgan(self):
self._test_patchgan_helper(create_callable_acgan_model)
if __name__ == '__main__':
test.main()