diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD
index cb2cd7c7ef0..c3ae738acf7 100644
--- a/tensorflow/contrib/gan/BUILD
+++ b/tensorflow/contrib/gan/BUILD
@@ -16,6 +16,60 @@ py_library(
     deps = [
         ":features",
         ":losses",
+        ":namedtuples",
+        ":train",
+    ],
+)
+
+py_library(
+    name = "namedtuples",
+    srcs = ["python/namedtuples.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "train",
+    srcs = ["python/train.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":losses",
+        ":namedtuples",
+        "//tensorflow/contrib/framework:framework_py",
+        "//tensorflow/contrib/slim:learning",
+        "//tensorflow/contrib/training:training_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:check_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/ops/distributions",
+        "//tensorflow/python/ops/losses",
+    ],
+)
+
+py_test(
+    name = "train_test",
+    srcs = ["python/train_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":namedtuples",
+        ":train",
+        "//tensorflow/contrib/framework:framework_py",
+        "//tensorflow/contrib/slim:learning",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:random_ops",
+        "//tensorflow/python:random_seed",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/ops/distributions",
+        "//third_party/py/numpy",
     ],
 )
diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py
index b2f4bf01190..3c423e72d0c 100644
--- a/tensorflow/contrib/gan/__init__.py
+++ b/tensorflow/contrib/gan/__init__.py
@@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2017 Google Inc. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +21,20 @@ from __future__ import print_function
 # Collapse TFGAN into a tiered namespace.
 from tensorflow.contrib.gan.python import features
 from tensorflow.contrib.gan.python import losses
+from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python import train
 
-del absolute_import
-del division
-del print_function
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.gan.python.namedtuples import *
+from tensorflow.contrib.gan.python.train import *
+# pylint: enable=unused-import,wildcard-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+    'features',
+    'losses',
+]
+_allowed_symbols += train.__all__
+_allowed_symbols += namedtuples.__all__
+remove_undocumented(__name__, _allowed_symbols)
diff --git a/tensorflow/contrib/gan/python/namedtuples.py b/tensorflow/contrib/gan/python/namedtuples.py
new file mode 100644
index 00000000000..a99e3fbec8d
--- /dev/null
+++ b/tensorflow/contrib/gan/python/namedtuples.py
@@ -0,0 +1,149 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Named tuples for TFGAN."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import collections
+
+
+__all__ = [
+    'GANModel',
+    'InfoGANModel',
+    'ACGANModel',
+    'GANLoss',
+    'GANTrainOps',
+    'GANTrainSteps',
+]
+
+
+class GANModel(
+    collections.namedtuple('GANModel', (
+        'generator_inputs',
+        'generated_data',
+        'generator_variables',
+        'generator_scope',
+        'generator_fn',
+        'real_data',
+        'discriminator_real_outputs',
+        'discriminator_gen_outputs',
+        'discriminator_variables',
+        'discriminator_scope',
+        'discriminator_fn',
+    ))):
+  """A GANModel contains all the pieces needed for GAN training.
+
+  Generative Adversarial Networks (https://arxiv.org/abs/1406.2661) attempt
+  to create an implicit generative model of data by solving a two agent game.
+  The generator generates candidate examples that are supposed to match the
+  data distribution, and the discriminator aims to tell the real examples
+  apart from the generated samples.
+
+  Args:
+    generator_inputs: The random noise source that acts as input to the
+      generator.
+    generated_data: The generated output data of the GAN.
+    generator_variables: A list of all generator variables.
+    generator_scope: Variable scope all generator variables live in.
+    generator_fn: The generator function.
+    real_data: A tensor of real data.
+    discriminator_real_outputs: The discriminator's output on real data.
+    discriminator_gen_outputs: The discriminator's output on generated data.
+    discriminator_variables: A list of all discriminator variables.
+    discriminator_scope: Variable scope all discriminator variables live in.
+    discriminator_fn: The discriminator function.
+  """
+
+
+# TODO(joelshor): Have this class inherit from `GANModel`.
+class InfoGANModel(
+    collections.namedtuple('InfoGANModel', GANModel._fields + (
+        'structured_generator_inputs',
+        'predicted_distributions',
+    ))):
+  """An InfoGANModel contains all the pieces needed for InfoGAN training.
+
+  See https://arxiv.org/abs/1606.03657 for more details.
+
+  Args:
+    structured_generator_inputs: A list of Tensors representing the random
+      noise that must have high mutual information with the generator output.
+      List length should match `predicted_distributions`.
+    predicted_distributions: A list of tf.Distributions. Predicted by the
+      recognizer, and used to evaluate the likelihood of the structured noise.
+      List length should match `structured_generator_inputs`.
+  """
+
+
+class ACGANModel(
+    collections.namedtuple('ACGANModel', GANModel._fields +
+                           ('one_hot_labels',
+                            'discriminator_real_classification_logits',
+                            'discriminator_gen_classification_logits',))):
+  """An ACGANModel contains all the pieces needed for ACGAN training.
+
+  See https://arxiv.org/abs/1610.09585 for more details.
+
+  Args:
+    one_hot_labels: A Tensor holding one-hot labels for the batch.
+    discriminator_real_classification_logits: Classification logits for real
+      data.
+    discriminator_gen_classification_logits: Classification logits for
+      generated data.
+ """ + + +class GANLoss( + collections.namedtuple('GANLoss', ( + 'generator_loss', + 'discriminator_loss' + ))): + """GANLoss contains the generator and discriminator losses. + + Args: + generator_loss: A tensor for the generator loss.. + discriminator_loss: A tensor for the discriminator loss. + """ + + +class GANTrainOps( + collections.namedtuple('GANTrainOps', ( + 'generator_train_op', + 'discriminator_train_op', + 'global_step_inc_op' + ))): + """GANTrainOps contains the training ops. + + Args: + generator_train_op: Op that performs a generator update step. + discriminator_train_op: Op that performs a discriminator update step. + global_step_inc_op: Op that increments the shared global step. + """ + + +class GANTrainSteps( + collections.namedtuple('GANTrainSteps', ( + 'generator_train_steps', + 'discriminator_train_steps' + ))): + """Contains configuration for the GAN Training. + + Args: + generator_train_steps: Number of generator steps to take in each GAN step. + discriminator_train_steps: Number of discriminator steps to take in each GAN + step. + """ diff --git a/tensorflow/contrib/gan/python/train.py b/tensorflow/contrib/gan/python/train.py new file mode 100644 index 00000000000..af7dbcf249a --- /dev/null +++ b/tensorflow/contrib/gan/python/train.py @@ -0,0 +1,804 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""The TFGAN project provides a lightweight GAN training/testing framework. + +See examples in `tensorflow_models` for details on how to use. 
+""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.framework.python.ops import variables as variables_lib +from tensorflow.contrib.gan.python import losses as tfgan_losses +from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.slim.python.slim import learning as slim_learning +from tensorflow.contrib.training.python.training import training +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import check_ops +from tensorflow.python.ops import init_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.ops.distributions import distribution as ds +from tensorflow.python.ops.losses import losses +from tensorflow.python.training import session_run_hook +from tensorflow.python.training import sync_replicas_optimizer +from tensorflow.python.training import training_util + + +__all__ = [ + 'gan_model', + 'infogan_model', + 'acgan_model', + 'gan_loss', + 'gan_train_ops', + 'gan_train', + 'get_sequential_train_hooks', + 'get_joint_train_hooks', + 'get_sequential_train_steps', +] + + +def _convert_tensor_or_l_or_d(tensor_or_l_or_d): + """Convert input, list of inputs, or dictionary of inputs to Tensors.""" + if isinstance(tensor_or_l_or_d, (list, tuple)): + return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d] + elif isinstance(tensor_or_l_or_d, dict): + return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()} + else: + return ops.convert_to_tensor(tensor_or_l_or_d) + + +def gan_model( + # Lambdas defining models. + generator_fn, + discriminator_fn, + # Real data and conditioning. + real_data, + generator_inputs, + # Optional scopes. + generator_scope='Generator', + discriminator_scope='Discriminator', + # Options. + check_shapes=True): + """Returns GAN model outputs and variables. + + Args: + generator_fn: A python lambda that takes `generator_inputs` as inputs and + returns the outputs of the GAN generator. + discriminator_fn: A python lambda that takes `real_data`/`generated data` + and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. + real_data: A Tensor representing the real data. + generator_inputs: A Tensor or list of Tensors to the generator. In the + vanilla GAN case, this might be a single noise Tensor. In the conditional + GAN case, this might be the generator's conditioning. + generator_scope: Optional generator variable scope. Useful if you want to + reuse a subgraph that has already been created. + discriminator_scope: Optional discriminator variable scope. Useful if you + want to reuse a subgraph that has already been created. + check_shapes: If `True`, check that generator produces Tensors that are the + same shape as real data. Otherwise, skip this check. + + Returns: + A GANModel namedtuple. + + Raises: + ValueError: If the generator outputs a Tensor that isn't the same shape as + `real_data`. 
+ """ + # Create models + with variable_scope.variable_scope(generator_scope) as gen_scope: + generator_inputs = _convert_tensor_or_l_or_d(generator_inputs) + generated_data = generator_fn(generator_inputs) + with variable_scope.variable_scope(discriminator_scope) as dis_scope: + discriminator_gen_outputs = discriminator_fn(generated_data, + generator_inputs) + with variable_scope.variable_scope(dis_scope, reuse=True): + real_data = ops.convert_to_tensor(real_data) + discriminator_real_outputs = discriminator_fn(real_data, generator_inputs) + + if check_shapes: + if not generated_data.shape.is_compatible_with(real_data.shape): + raise ValueError( + 'Generator output shape (%s) must be the same shape as real data ' + '(%s).' % (generated_data.shape, real_data.shape)) + + # Get model-specific variables. + generator_variables = variables_lib.get_trainable_variables(gen_scope) + discriminator_variables = variables_lib.get_trainable_variables(dis_scope) + + return namedtuples.GANModel( + generator_inputs, + generated_data, + generator_variables, + gen_scope, + generator_fn, + real_data, + discriminator_real_outputs, + discriminator_gen_outputs, + discriminator_variables, + dis_scope, + discriminator_fn) + + +def _validate_distributions(distributions_l, noise_l): + if not isinstance(distributions_l, (tuple, list)): + raise ValueError('`predicted_distributions` must be a list. Instead, found ' + '%s.' % type(distributions_l)) + for dist in distributions_l: + if not isinstance(dist, ds.Distribution): + raise ValueError('Every element in `predicted_distributions` must be a ' + '`tf.Distribution`. Instead, found %s.' % type(dist)) + if len(distributions_l) != len(noise_l): + raise ValueError('Length of `predicted_distributions` %i must be the same ' + 'as the length of structured noise %i.' % + (len(distributions_l), len(noise_l))) + + +def infogan_model( + # Lambdas defining models. + generator_fn, + discriminator_fn, + # Real data and conditioning. + real_data, + unstructured_generator_inputs, + structured_generator_inputs, + # Optional scopes. + generator_scope='Generator', + discriminator_scope='Discriminator'): + """Returns an InfoGAN model outputs and variables. + + See https://arxiv.org/abs/1606.03657 for more details. + + Args: + generator_fn: A python lambda that takes a list of Tensors as inputs and + returns the outputs of the GAN generator. + discriminator_fn: A python lambda that takes `real_data`/`generated data` + and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list). + `logits` are in the range [-inf, inf], and `distribution_list` is a list + of Tensorflow distributions representing the predicted noise distribution + of the ith structure noise. + real_data: A Tensor representing the real data. + unstructured_generator_inputs: A list of Tensors to the generator. + These tensors represent the unstructured noise or conditioning. + structured_generator_inputs: A list of Tensors to the generator. + These tensors must have high mutual information with the recognizer. + generator_scope: Optional generator variable scope. Useful if you want to + reuse a subgraph that has already been created. + discriminator_scope: Optional discriminator variable scope. Useful if you + want to reuse a subgraph that has already been created. + + Returns: + An InfoGANModel namedtuple. + + Raises: + ValueError: If the generator outputs a Tensor that isn't the same shape as + `real_data`. + ValueError: If the discriminator output is malformed. 
+ """ + # Create models + with variable_scope.variable_scope(generator_scope) as gen_scope: + unstructured_generator_inputs = _convert_tensor_or_l_or_d( + unstructured_generator_inputs) + structured_generator_inputs = _convert_tensor_or_l_or_d( + structured_generator_inputs) + generator_inputs = ( + unstructured_generator_inputs + structured_generator_inputs) + generated_data = generator_fn(generator_inputs) + with variable_scope.variable_scope(discriminator_scope) as disc_scope: + dis_gen_outputs, predicted_distributions = discriminator_fn( + generated_data, generator_inputs) + _validate_distributions(predicted_distributions, structured_generator_inputs) + with variable_scope.variable_scope(disc_scope, reuse=True): + real_data = ops.convert_to_tensor(real_data) + dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs) + + if not generated_data.get_shape().is_compatible_with(real_data.get_shape()): + raise ValueError( + 'Generator output shape (%s) must be the same shape as real data ' + '(%s).' % (generated_data.get_shape(), real_data.get_shape())) + + # Get model-specific variables. + generator_variables = variables_lib.get_trainable_variables(gen_scope) + discriminator_variables = variables_lib.get_trainable_variables( + disc_scope) + + return namedtuples.InfoGANModel( + generator_inputs, + generated_data, + generator_variables, + gen_scope, + generator_fn, + real_data, + dis_real_outputs, + dis_gen_outputs, + discriminator_variables, + disc_scope, + lambda x, y: discriminator_fn(x, y)[0], # conform to non-InfoGAN API + structured_generator_inputs, + predicted_distributions) + + +def _validate_acgan_discriminator_outputs(discriminator_output): + try: + a, b = discriminator_output + except (TypeError, ValueError): + raise TypeError( + 'A discriminator function for ACGAN must output a tuple ' + 'consisting of (discrimination logits, classification logits).') + return a, b + + +def acgan_model( + # Lambdas defining models. + generator_fn, + discriminator_fn, + # Real data and conditioning. + real_data, + generator_inputs, + one_hot_labels, + # Optional scopes. + generator_scope='Generator', + discriminator_scope='Discriminator', + check_shapes=True): + """Returns an ACGANModel contains all the pieces needed for ACGAN training. + + The `acgan_model` is the same as the `gan_model` with the only difference + being that the discriminator additionally outputs logits to classify the input + (real or generated). + Therefore, an explicit field holding one_hot_labels is necessary, as well as a + discriminator_fn that outputs a 2-tuple holding the logits for real/fake and + classification. + + See https://arxiv.org/abs/1610.09585 for more details. + + Args: + generator_fn: A python lambda that takes `generator_inputs` as inputs and + returns the outputs of the GAN generator. + discriminator_fn: A python lambda that takes `real_data`/`generated data` + and `generator_inputs`. Outputs a tuple consisting of two Tensors: + (1) real/fake logits in the range [-inf, inf] + (2) classification logits in the range [-inf, inf] + real_data: A Tensor representing the real data. + generator_inputs: A Tensor or list of Tensors to the generator. In the + vanilla GAN case, this might be a single noise Tensor. In the conditional + GAN case, this might be the generator's conditioning. + one_hot_labels: A Tensor holding one-hot-labels for the batch. Needed by + acgan_loss. + generator_scope: Optional generator variable scope. Useful if you want to + reuse a subgraph that has already been created. 
+    discriminator_scope: Optional discriminator variable scope. Useful if you
+      want to reuse a subgraph that has already been created.
+    check_shapes: If `True`, check that generator produces Tensors that are
+      the same shape as real data. Otherwise, skip this check.
+
+  Returns:
+    A ACGANModel namedtuple.
+
+  Raises:
+    ValueError: If the generator outputs a Tensor that isn't the same shape as
+      `real_data`.
+    TypeError: If the discriminator does not output a tuple consisting of
+      (discrimination logits, classification logits).
+  """
+  # Create models
+  with variable_scope.variable_scope(generator_scope) as gen_scope:
+    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
+    generated_data = generator_fn(generator_inputs)
+  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
+    (discriminator_gen_outputs, discriminator_gen_classification_logits
+    ) = _validate_acgan_discriminator_outputs(
+        discriminator_fn(generated_data, generator_inputs))
+  with variable_scope.variable_scope(dis_scope, reuse=True):
+    real_data = ops.convert_to_tensor(real_data)
+    (discriminator_real_outputs, discriminator_real_classification_logits
+    ) = _validate_acgan_discriminator_outputs(
+        discriminator_fn(real_data, generator_inputs))
+  if check_shapes:
+    if not generated_data.shape.is_compatible_with(real_data.shape):
+      raise ValueError(
+          'Generator output shape (%s) must be the same shape as real data '
+          '(%s).' % (generated_data.shape, real_data.shape))
+
+  # Get model-specific variables.
+  generator_variables = variables_lib.get_trainable_variables(gen_scope)
+  discriminator_variables = variables_lib.get_trainable_variables(
+      dis_scope)
+
+  return namedtuples.ACGANModel(
+      generator_inputs, generated_data, generator_variables, gen_scope,
+      generator_fn, real_data, discriminator_real_outputs,
+      discriminator_gen_outputs, discriminator_variables, dis_scope,
+      discriminator_fn, one_hot_labels,
+      discriminator_real_classification_logits,
+      discriminator_gen_classification_logits)
+
+
+def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
+  if isinstance(aux_loss_weight, ops.Tensor):
+    aux_loss_weight.shape.assert_is_compatible_with([])
+    with ops.control_dependencies(
+        [check_ops.assert_greater_equal(aux_loss_weight, 0.0)]):
+      aux_loss_weight = array_ops.identity(aux_loss_weight)
+  elif aux_loss_weight is not None and aux_loss_weight < 0:
+    raise ValueError('`%s` must be non-negative. Instead, was %s' %
+                     (name, aux_loss_weight))
+  return aux_loss_weight
+
+
+def _use_aux_loss(aux_loss_weight):
+  if aux_loss_weight is not None:
+    if not isinstance(aux_loss_weight, ops.Tensor):
+      return aux_loss_weight > 0
+    else:
+      return True
+  else:
+    return False
+
+
+def gan_loss(
+    # GANModel.
+    model,
+    # Loss functions.
+    generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
+    discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
+    # Auxiliary losses.
+    gradient_penalty_weight=None,
+    gradient_penalty_epsilon=1e-10,
+    mutual_information_penalty_weight=None,
+    aux_cond_generator_weight=None,
+    aux_cond_discriminator_weight=None,
+    # Options.
+    add_summaries=True):
+  """Returns losses necessary to train generator and discriminator.
+
+  Args:
+    model: A GANModel tuple.
+    generator_loss_fn: The loss function on the generator. Takes a GANModel
+      tuple.
+    discriminator_loss_fn: The loss function on the discriminator. Takes a
+      GANModel tuple.
+    gradient_penalty_weight: If not `None`, must be a non-negative Python
+      number or Tensor indicating how much to weight the gradient penalty. See
+      https://arxiv.org/pdf/1704.00028.pdf for more details.
+    gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
+      small positive value used by the gradient penalty function for numerical
+      stability. Note some applications will need to increase this value to
+      avoid NaNs.
+    mutual_information_penalty_weight: If not `None`, must be a non-negative
+      Python number or Tensor indicating how much to weight the mutual
+      information penalty. See https://arxiv.org/abs/1606.03657 for more
+      details.
+    aux_cond_generator_weight: If not `None`, add a classification loss as in
+      https://arxiv.org/abs/1610.09585.
+    aux_cond_discriminator_weight: If not `None`, add a classification loss as
+      in https://arxiv.org/abs/1610.09585.
+    add_summaries: Whether or not to add summaries for the losses.
+
+  Returns:
+    A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
+    regularization losses.
+
+  Raises:
+    ValueError: If any of the auxiliary loss weights is provided and negative.
+    ValueError: If `mutual_information_penalty_weight` is provided, but the
+      `model` isn't an `InfoGANModel`.
+  """
+  # Validate arguments.
+  gradient_penalty_weight = _validate_aux_loss_weight(
+      gradient_penalty_weight, 'gradient_penalty_weight')
+  mutual_information_penalty_weight = _validate_aux_loss_weight(
+      mutual_information_penalty_weight, 'infogan_weight')
+  aux_cond_generator_weight = _validate_aux_loss_weight(
+      aux_cond_generator_weight, 'aux_cond_generator_weight')
+  aux_cond_discriminator_weight = _validate_aux_loss_weight(
+      aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')
+
+  # Verify configuration for mutual information penalty.
+  if (_use_aux_loss(mutual_information_penalty_weight) and
+      not isinstance(model, namedtuples.InfoGANModel)):
+    raise ValueError(
+        'When `mutual_information_penalty_weight` is provided, `model` must '
+        'be an `InfoGANModel`. Instead, was %s.' % type(model))
+
+  # Verify configuration for auxiliary condition loss (ACGAN).
+  if ((_use_aux_loss(aux_cond_generator_weight) or
+       _use_aux_loss(aux_cond_discriminator_weight)) and
+      not isinstance(model, namedtuples.ACGANModel)):
+    raise ValueError(
+        'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
+        'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
+        type(model))
+
+  # Create standard losses.
+  gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
+  dis_loss = discriminator_loss_fn(model, add_summaries=add_summaries)
+
+  # Add optional extra losses.
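+  # Sketch of how the pieces below combine (each term is added only if its
+  # weight passes `_use_aux_loss`):
+  #   gen_loss += mutual_information_penalty_weight * info_loss
+  #            +  aux_cond_generator_weight * ac_gen_loss
+  #   dis_loss += gradient_penalty_weight * gp_loss
+  #            +  mutual_information_penalty_weight * info_loss
+  #            +  aux_cond_discriminator_weight * ac_disc_loss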
+  if _use_aux_loss(gradient_penalty_weight):
+    gp_loss = tfgan_losses.wasserstein_gradient_penalty(
+        model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries)
+    dis_loss += gradient_penalty_weight * gp_loss
+  if _use_aux_loss(mutual_information_penalty_weight):
+    info_loss = tfgan_losses.mutual_information_penalty(
+        model, add_summaries=add_summaries)
+    dis_loss += mutual_information_penalty_weight * info_loss
+    gen_loss += mutual_information_penalty_weight * info_loss
+  if _use_aux_loss(aux_cond_generator_weight):
+    ac_gen_loss = tfgan_losses.acgan_generator_loss(
+        model, add_summaries=add_summaries)
+    gen_loss += aux_cond_generator_weight * ac_gen_loss
+  if _use_aux_loss(aux_cond_discriminator_weight):
+    ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
+        model, add_summaries=add_summaries)
+    dis_loss += aux_cond_discriminator_weight * ac_disc_loss
+
+  # Gather auxiliary (regularization) losses.
+  if model.generator_scope:
+    gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
+  else:
+    gen_reg_loss = 0
+  if model.discriminator_scope:
+    dis_reg_loss = losses.get_regularization_loss(
+        model.discriminator_scope.name)
+  else:
+    dis_reg_loss = 0
+
+  return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)
+
+
+def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
+  """Gets generator and discriminator update ops.
+
+  Args:
+    kwargs: A dictionary of kwargs to be passed to `create_train_op`.
+      `update_ops` is removed, if present.
+    gen_scope: A scope for the generator.
+    dis_scope: A scope for the discriminator.
+    check_for_unused_ops: A Python bool. If `True`, throw Exception if there
+      are unused update ops.
+
+  Returns:
+    A 2-tuple of (generator update ops, discriminator update ops).
+
+  Raises:
+    ValueError: If there are update ops outside of the generator or
+      discriminator scopes.
+  """
+  if 'update_ops' in kwargs:
+    update_ops = set(kwargs['update_ops'])
+    del kwargs['update_ops']
+  else:
+    update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))
+
+  all_gen_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, gen_scope))
+  all_dis_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, dis_scope))
+
+  if check_for_unused_ops:
+    unused_ops = update_ops - all_gen_ops - all_dis_ops
+    if unused_ops:
+      raise ValueError('There are unused update ops: %s' % unused_ops)
+
+  gen_update_ops = list(all_gen_ops & update_ops)
+  dis_update_ops = list(all_dis_ops & update_ops)
+
+  return gen_update_ops, dis_update_ops
+
+
+def gan_train_ops(
+    model,  # GANModel
+    loss,  # GANLoss
+    generator_optimizer,
+    discriminator_optimizer,
+    # Optional check flags.
+    check_for_unused_update_ops=True,
+    # Optional args to pass directly to the `create_train_op`.
+    **kwargs):
+  """Returns GAN train ops.
+
+  The highest-level call in TFGAN. It is composed of functions that can also
+  be called, should a user require more control over some part of the GAN
+  training process.
+
+  Args:
+    model: A GANModel.
+    loss: A GANLoss.
+    generator_optimizer: The optimizer for generator updates.
+    discriminator_optimizer: The optimizer for the discriminator updates.
+    check_for_unused_update_ops: If `True`, throws an exception if there are
+      update ops outside of the generator or discriminator scopes.
+    **kwargs: Keyword args to pass directly to
+      `training.create_train_op` for both the generator and
+      discriminator train op.
+
+  Returns:
+    A GANTrainOps tuple of (generator_train_op, discriminator_train_op,
+    global_step_inc_op) that can be used to train a generator/discriminator
+    pair.
+  """
+  # Create global step increment op.
+  global_step = training_util.get_or_create_global_step()
+  global_step_inc = global_step.assign_add(1)
+
+  # Get generator and discriminator update ops. We split them so that update
+  # ops aren't accidentally run multiple times. For now, throw an error if
+  # there are update ops that aren't associated with either the generator or
+  # the discriminator. Might modify the `kwargs` dictionary.
+  gen_update_ops, dis_update_ops = _get_update_ops(
+      kwargs, model.generator_scope.name, model.discriminator_scope.name,
+      check_for_unused_update_ops)
+
+  generator_global_step = None
+  if isinstance(generator_optimizer,
+                sync_replicas_optimizer.SyncReplicasOptimizer):
+    # TODO(joelshor): Figure out a way to get this to work without including
+    # the dummy global step in the checkpoint.
+    # WARNING: Making this variable a local variable causes sync replicas to
+    # hang forever.
+    generator_global_step = variable_scope.get_variable(
+        'dummy_global_step_generator',
+        shape=[],
+        dtype=dtypes.int64,
+        initializer=init_ops.zeros_initializer(),
+        trainable=False,
+        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+    gen_update_ops += [generator_global_step.assign(global_step)]
+  with ops.name_scope('generator_train'):
+    gen_train_op = training.create_train_op(
+        total_loss=loss.generator_loss,
+        optimizer=generator_optimizer,
+        variables_to_train=model.generator_variables,
+        global_step=generator_global_step,
+        update_ops=gen_update_ops,
+        **kwargs)
+
+  discriminator_global_step = None
+  if isinstance(discriminator_optimizer,
+                sync_replicas_optimizer.SyncReplicasOptimizer):
+    # See comment above `generator_global_step`.
+    discriminator_global_step = variable_scope.get_variable(
+        'dummy_global_step_discriminator',
+        shape=[],
+        dtype=dtypes.int64,
+        initializer=init_ops.zeros_initializer(),
+        trainable=False,
+        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
+    dis_update_ops += [discriminator_global_step.assign(global_step)]
+  with ops.name_scope('discriminator_train'):
+    disc_train_op = training.create_train_op(
+        total_loss=loss.discriminator_loss,
+        optimizer=discriminator_optimizer,
+        variables_to_train=model.discriminator_variables,
+        global_step=discriminator_global_step,
+        update_ops=dis_update_ops,
+        **kwargs)
+
+  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)
+
+
+# TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive
+# Image Compression` (https://arxiv.org/abs/1705.05823)
+class RunTrainOpsHook(session_run_hook.SessionRunHook):
+  """A hook to run train ops a fixed number of times."""
+
+  def __init__(self, train_ops, train_steps):
+    """Run train ops a certain number of times.
+
+    Args:
+      train_ops: A train op or iterable of train ops to run.
+      train_steps: The number of times to run the op(s).
+    """
+    if not isinstance(train_ops, (list, tuple)):
+      train_ops = [train_ops]
+    self._train_ops = train_ops
+    self._train_steps = train_steps
+
+  def before_run(self, run_context):
+    for _ in range(self._train_steps):
+      run_context.session.run(self._train_ops)
+
+
+def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
+  """Returns a hooks function for sequential GAN training.
+
+  Args:
+    train_steps: A `GANTrainSteps` tuple that determines how many generator
+      and discriminator training steps to take.
+
+  Returns:
+    A function that takes a GANTrainOps tuple and returns a list of hooks.
+  """
+  def get_hooks(train_ops):
+    generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
+                                     train_steps.generator_train_steps)
+    discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
+                                         train_steps.discriminator_train_steps)
+    return [generator_hook, discriminator_hook]
+  return get_hooks
+
+
+def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
+  """Returns a hooks function for joint GAN training.
+
+  When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON
+  ALL OPTIMIZERS TO AVOID RACE CONDITIONS.
+
+  The order of steps taken is:
+  1) Combined generator and discriminator steps
+  2) Generator only steps, if any remain
+  3) Discriminator only steps, if any remain
+
+  **NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates
+  for the generator and discriminator simultaneously whenever possible. This
+  reduces the number of `tf.Session` calls, and can also change the training
+  semantics.
+
+  To illustrate the difference look at the following example:
+
+  `train_steps=namedtuples.GANTrainSteps(3, 5)` will cause
+  `get_sequential_train_hooks` to make 8 session calls:
+  1) 3 generator steps
+  2) 5 discriminator steps
+
+  In contrast, `get_joint_train_hooks` will make 5 session calls:
+  1) 3 generator + discriminator steps
+  2) 2 discriminator steps
+
+  Args:
+    train_steps: A `GANTrainSteps` tuple that determines how many generator
+      and discriminator training steps to take.
+
+  Returns:
+    A function that takes a GANTrainOps tuple and returns a list of hooks.
+  """
+  g_steps = train_steps.generator_train_steps
+  d_steps = train_steps.discriminator_train_steps
+  # Get the number of each type of step that should be run.
+  num_d_and_g_steps = min(g_steps, d_steps)
+  num_g_steps = g_steps - num_d_and_g_steps
+  num_d_steps = d_steps - num_d_and_g_steps
+
+  def get_hooks(train_ops):
+    g_op = train_ops.generator_train_op
+    d_op = train_ops.discriminator_train_op
+
+    joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps)
+    g_hook = RunTrainOpsHook(g_op, num_g_steps)
+    d_hook = RunTrainOpsHook(d_op, num_d_steps)
+
+    return [joint_hook, g_hook, d_hook]
+  return get_hooks
+
+
+# TODO(joelshor): This function currently returns the global step. Find a
+# good way for it to return the generator, discriminator, and final losses.
+def gan_train(
+    train_ops,
+    logdir,
+    get_hooks_fn=get_sequential_train_hooks(),
+    master='',
+    is_chief=True,
+    scaffold=None,
+    hooks=None,
+    chief_only_hooks=None,
+    save_checkpoint_secs=600,
+    save_summaries_steps=100,
+    config=None):
+  """A wrapper around `contrib.training.train` that uses GAN hooks.
+
+  Args:
+    train_ops: A GANTrainOps named tuple.
+    logdir: The directory where the graph and checkpoints are saved.
+    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
+      of hooks.
+    master: The URL of the master.
+    is_chief: Specifies whether or not the training is being run by the
+      primary replica during replica training.
+    scaffold: A `tf.train.Scaffold` instance.
+    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside
+      the training loop.
+    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are
+      run inside the training loop for the chief trainer only.
+    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
+      saved using a default checkpoint saver.
+      If `save_checkpoint_secs` is set to `None`, then the default checkpoint
+      saver isn't used.
+    save_summaries_steps: The frequency, in number of global steps, that the
+      summaries are written to disk using a default summary saver. If
+      `save_summaries_steps` is set to `None`, then the default summary saver
+      isn't used.
+    config: An instance of `tf.ConfigProto`.
+
+  Returns:
+    Output of the call to `training.train`.
+  """
+  new_hooks = get_hooks_fn(train_ops)
+  if hooks is not None:
+    hooks = list(hooks) + list(new_hooks)
+  else:
+    hooks = new_hooks
+  return training.train(
+      train_ops.global_step_inc_op,
+      logdir,
+      master=master,
+      is_chief=is_chief,
+      scaffold=scaffold,
+      hooks=hooks,
+      chief_only_hooks=chief_only_hooks,
+      save_checkpoint_secs=save_checkpoint_secs,
+      save_summaries_steps=save_summaries_steps,
+      config=config)
+
+
+def get_sequential_train_steps(
+    train_steps=namedtuples.GANTrainSteps(1, 1)):
+  """Returns a thin wrapper around slim.learning.train_step, for GANs.
+
+  This function is to provide support for the Supervisor. For new code, please
+  use `MonitoredSession` and `get_sequential_train_hooks`.
+
+  Args:
+    train_steps: A `GANTrainSteps` tuple that determines how many generator
+      and discriminator training steps to take.
+
+  Returns:
+    A function that can be used for `train_step_fn` for GANs.
+  """
+
+  def sequential_train_steps(sess, train_ops, global_step, train_step_kwargs):
+    """A thin wrapper around slim.learning.train_step, for GANs.
+
+    Args:
+      sess: A Tensorflow session.
+      train_ops: A GANTrainOps tuple of train ops to run.
+      global_step: The global step.
+      train_step_kwargs: Dictionary controlling `train_step` behavior.
+
+    Returns:
+      A scalar final loss and a bool whether or not the train loop should
+      stop.
+    """
+    # Only run `should_stop` at the end, if required. Make a local copy of
+    # `train_step_kwargs`, if necessary, so as not to modify the caller's
+    # dictionary.
+    should_stop_op, train_kwargs = None, train_step_kwargs
+    if 'should_stop' in train_step_kwargs:
+      should_stop_op = train_step_kwargs['should_stop']
+      train_kwargs = train_step_kwargs.copy()
+      del train_kwargs['should_stop']
+
+    # Run generator training steps.
+    gen_loss = 0
+    for _ in range(train_steps.generator_train_steps):
+      cur_gen_loss, _ = slim_learning.train_step(
+          sess, train_ops.generator_train_op, global_step, train_kwargs)
+      gen_loss += cur_gen_loss
+
+    # Run discriminator training steps.
+    dis_loss = 0
+    for _ in range(train_steps.discriminator_train_steps):
+      cur_dis_loss, _ = slim_learning.train_step(
+          sess, train_ops.discriminator_train_op, global_step, train_kwargs)
+      dis_loss += cur_dis_loss
+
+    sess.run(train_ops.global_step_inc_op)
+
+    # Run the `should_stop` op after the global step has been incremented, so
+    # that the `should_stop` aligns with the proper `global_step` count.
+    if should_stop_op is not None:
+      should_stop = sess.run(should_stop_op)
+    else:
+      should_stop = False
+
+    return gen_loss + dis_loss, should_stop
+
+  return sequential_train_steps
diff --git a/tensorflow/contrib/gan/python/train_test.py b/tensorflow/contrib/gan/python/train_test.py
new file mode 100644
index 00000000000..83b763806cd
--- /dev/null
+++ b/tensorflow/contrib/gan/python/train_test.py
@@ -0,0 +1,745 @@
+# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Tests for gan.python.train."""
+
+from __future__ import absolute_import
+from __future__ import division
+from __future__ import print_function
+
+import numpy as np
+
+from tensorflow.contrib.framework.python.ops import variables as variables_lib
+from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python import train
+from tensorflow.contrib.slim.python.slim import learning as slim_learning
+from tensorflow.python.framework import constant_op
+from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
+from tensorflow.python.framework import random_seed
+from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import random_ops
+from tensorflow.python.ops import variable_scope
+from tensorflow.python.ops import variables
+from tensorflow.python.ops.distributions import categorical
+from tensorflow.python.platform import test
+from tensorflow.python.training import basic_session_run_hooks
+from tensorflow.python.training import coordinator
+from tensorflow.python.training import gradient_descent
+from tensorflow.python.training import sync_replicas_optimizer
+from tensorflow.python.training import training_util
+
+
+def generator_model(inputs):
+  return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs
+
+
+class Generator(object):
+
+  def __call__(self, inputs):
+    return generator_model(inputs)
+
+
+def infogan_generator_model(inputs):
+  return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs[0]
+
+
+class InfoGANGenerator(object):
+
+  def __call__(self, inputs):
+    return infogan_generator_model(inputs)
+
+
+def discriminator_model(inputs, _):
+  return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
+
+
+class Discriminator(object):
+
+  def __call__(self, inputs, _):
+    return discriminator_model(inputs, _)
+
+
+def infogan_discriminator_model(inputs, _):
+  return (variable_scope.get_variable('dummy_d', initializer=2.0) * inputs,
+          [categorical.Categorical([1.0])])
+
+
+class InfoGANDiscriminator(object):
+
+  def __call__(self, inputs, _):
+    return infogan_discriminator_model(inputs, _)
+
+
+def acgan_discriminator_model(inputs, _, num_classes=10):
+  return (discriminator_model(inputs, _), array_ops.one_hot(
+      # TODO(haeusser): infer batch size from input
+      random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32),
+      num_classes))
+
+
+class ACGANDiscriminator(object):
+
+  def __call__(self, inputs, _, num_classes=10):
+    return (discriminator_model(inputs, _), array_ops.one_hot(
+        # TODO(haeusser): infer batch size from input
+        random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32),
+        num_classes))
+
+
+def get_gan_model():
+  # TODO(joelshor): Find a better way of creating a variable scope.
+  with variable_scope.variable_scope('generator') as gen_scope:
+    pass
+  with variable_scope.variable_scope('discriminator') as dis_scope:
+    pass
+  return namedtuples.GANModel(
+      generator_inputs=None,
+      generated_data=None,
+      generator_variables=None,
+      generator_scope=gen_scope,
+      generator_fn=generator_model,
+      real_data=array_ops.ones([1, 2, 3]),
+      discriminator_real_outputs=array_ops.ones([1, 2, 3]),
+      discriminator_gen_outputs=array_ops.ones([1, 2, 3]),
+      discriminator_variables=None,
+      discriminator_scope=dis_scope,
+      discriminator_fn=discriminator_model)
+
+
+def get_callable_gan_model():
+  ganmodel = get_gan_model()
+  return ganmodel._replace(
+      generator_fn=Generator(),
+      discriminator_fn=Discriminator())
+
+
+def create_gan_model():
+  return train.gan_model(
+      generator_model,
+      discriminator_model,
+      real_data=array_ops.zeros([1, 2]),
+      generator_inputs=random_ops.random_normal([1, 2]))
+
+
+def create_callable_gan_model():
+  return train.gan_model(
+      Generator(),
+      Discriminator(),
+      real_data=array_ops.zeros([1, 2]),
+      generator_inputs=random_ops.random_normal([1, 2]))
+
+
+def get_infogan_model():
+  return namedtuples.InfoGANModel(
+      *get_gan_model(),
+      structured_generator_inputs=[constant_op.constant(0)],
+      predicted_distributions=[categorical.Categorical([1.0])])
+
+
+def get_callable_infogan_model():
+  return namedtuples.InfoGANModel(
+      *get_callable_gan_model(),
+      structured_generator_inputs=[constant_op.constant(0)],
+      predicted_distributions=[categorical.Categorical([1.0])])
+
+
+def create_infogan_model():
+  return train.infogan_model(
+      infogan_generator_model,
+      infogan_discriminator_model,
+      real_data=array_ops.zeros([1, 2]),
+      unstructured_generator_inputs=[],
+      structured_generator_inputs=[random_ops.random_normal([1, 2])])
+
+
+def create_callable_infogan_model():
+  return train.infogan_model(
+      InfoGANGenerator(),
+      InfoGANDiscriminator(),
+      real_data=array_ops.zeros([1, 2]),
+      unstructured_generator_inputs=[],
+      structured_generator_inputs=[random_ops.random_normal([1, 2])])
+
+
+def get_acgan_model():
+  return namedtuples.ACGANModel(
+      *get_gan_model(),
+      one_hot_labels=array_ops.one_hot([0, 1, 2], 10),
+      discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10),
+      discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10))
+
+
+def get_callable_acgan_model():
+  return namedtuples.ACGANModel(
+      *get_callable_gan_model(),
+      one_hot_labels=array_ops.one_hot([0, 1, 2], 10),
+      discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10),
+      discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10))
+
+
+def create_acgan_model():
+  return train.acgan_model(
+      generator_model,
+      acgan_discriminator_model,
+      real_data=array_ops.zeros([1, 2]),
+      generator_inputs=random_ops.random_normal([1, 2]),
+      one_hot_labels=array_ops.one_hot([0, 1, 2], 10))
+
+
+def create_callable_acgan_model():
+  return train.acgan_model(
+      Generator(),
+      ACGANDiscriminator(),
+      real_data=array_ops.zeros([1, 2]),
+      generator_inputs=random_ops.random_normal([1, 2]),
+      one_hot_labels=array_ops.one_hot([0, 1, 2], 10))
+
+
+def get_sync_optimizer():
+  return sync_replicas_optimizer.SyncReplicasOptimizer(
+      gradient_descent.GradientDescentOptimizer(learning_rate=1.0),
+      replicas_to_aggregate=1)
+
+
+class GANModelTest(test.TestCase):
+  """Tests for `gan_model`."""
+
+  def _test_output_type_helper(self, create_fn, tuple_type):
+    self.assertTrue(isinstance(create_fn(), tuple_type))
+
+  def test_output_type_gan(self):
+    self._test_output_type_helper(get_gan_model, namedtuples.GANModel)
+
+  def test_output_type_callable_gan(self):
+    self._test_output_type_helper(get_callable_gan_model, namedtuples.GANModel)
+
+  def test_output_type_infogan(self):
+    self._test_output_type_helper(get_infogan_model, namedtuples.InfoGANModel)
+
+  def test_output_type_callable_infogan(self):
+    self._test_output_type_helper(
+        get_callable_infogan_model, namedtuples.InfoGANModel)
+
+  def test_output_type_acgan(self):
+    self._test_output_type_helper(get_acgan_model, namedtuples.ACGANModel)
+
+  def test_output_type_callable_acgan(self):
+    self._test_output_type_helper(
+        get_callable_acgan_model, namedtuples.ACGANModel)
+
+  def test_no_shape_check(self):
+    def dummy_generator_model(_):
+      return (None, None)
+    def dummy_discriminator_model(data, conditioning):  # pylint: disable=unused-argument
+      return 1
+    with self.assertRaisesRegexp(AttributeError, 'object has no attribute'):
+      train.gan_model(
+          dummy_generator_model,
+          dummy_discriminator_model,
+          real_data=array_ops.zeros([1, 2]),
+          generator_inputs=array_ops.zeros([1]),
+          check_shapes=True)
+    train.gan_model(
+        dummy_generator_model,
+        dummy_discriminator_model,
+        real_data=array_ops.zeros([1, 2]),
+        generator_inputs=array_ops.zeros([1]),
+        check_shapes=False)
+
+
+class GANLossTest(test.TestCase):
+  """Tests for `gan_loss`."""
+
+  # Test output type.
+  def _test_output_type_helper(self, get_gan_model_fn):
+    loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
+    self.assertTrue(isinstance(loss, namedtuples.GANLoss))
+    self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
+
+  def test_output_type_gan(self):
+    self._test_output_type_helper(get_gan_model)
+
+  def test_output_type_callable_gan(self):
+    self._test_output_type_helper(get_callable_gan_model)
+
+  def test_output_type_infogan(self):
+    self._test_output_type_helper(get_infogan_model)
+
+  def test_output_type_callable_infogan(self):
+    self._test_output_type_helper(get_callable_infogan_model)
+
+  def test_output_type_acgan(self):
+    self._test_output_type_helper(get_acgan_model)
+
+  def test_output_type_callable_acgan(self):
+    self._test_output_type_helper(get_callable_acgan_model)
+
+  # Test gradient penalty option.
+  def _test_grad_penalty_helper(self, create_gan_model_fn):
+    model = create_gan_model_fn()
+    loss = train.gan_loss(model)
+    loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0)
+    self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss))
+
+    # Check values.
+    with self.test_session(use_gpu=True) as sess:
+      variables.global_variables_initializer().run()
+      loss_gen_np, loss_gen_gp_np = sess.run(
+          [loss.generator_loss, loss_gp.generator_loss])
+      loss_dis_np, loss_dis_gp_np = sess.run(
+          [loss.discriminator_loss, loss_gp.discriminator_loss])
+
+    self.assertEqual(loss_gen_np, loss_gen_gp_np)
+    self.assertTrue(loss_dis_np < loss_dis_gp_np)
+
+  def test_grad_penalty_gan(self):
+    self._test_grad_penalty_helper(create_gan_model)
+
+  def test_grad_penalty_callable_gan(self):
+    self._test_grad_penalty_helper(create_callable_gan_model)
+
+  def test_grad_penalty_infogan(self):
+    self._test_grad_penalty_helper(create_infogan_model)
+
+  def test_grad_penalty_callable_infogan(self):
+    self._test_grad_penalty_helper(create_callable_infogan_model)
+
+  def test_grad_penalty_acgan(self):
+    self._test_grad_penalty_helper(create_acgan_model)
+
+  def test_grad_penalty_callable_acgan(self):
+    self._test_grad_penalty_helper(create_callable_acgan_model)
+
+  # Test mutual information penalty option.
+  def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
+    train.gan_loss(create_gan_model_fn(),
+                   mutual_information_penalty_weight=constant_op.constant(1.0))
+
+  def test_mutual_info_penalty_infogan(self):
+    self._test_mutual_info_penalty_helper(get_infogan_model)
+
+  def test_mutual_info_penalty_callable_infogan(self):
+    self._test_mutual_info_penalty_helper(get_callable_infogan_model)
+
+  # Test regularization loss.
+  def _test_regularization_helper(self, get_gan_model_fn):
+    # Evaluate losses without regularization.
+    no_reg_loss = train.gan_loss(get_gan_model_fn())
+    with self.test_session(use_gpu=True):
+      no_reg_loss_gen_np = no_reg_loss.generator_loss.eval()
+      no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()
+
+    with ops.name_scope(get_gan_model_fn().generator_scope.name):
+      ops.add_to_collection(
+          ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
+    with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
+      ops.add_to_collection(
+          ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))
+
+    # Check that losses now include the correct regularization values.
+    reg_loss = train.gan_loss(get_gan_model_fn())
+    with self.test_session(use_gpu=True):
+      reg_loss_gen_np = reg_loss.generator_loss.eval()
+      reg_loss_dis_np = reg_loss.discriminator_loss.eval()
+
+    self.assertEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np)
+    self.assertEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np)
+
+  def test_regularization_gan(self):
+    self._test_regularization_helper(get_gan_model)
+
+  def test_regularization_callable_gan(self):
+    self._test_regularization_helper(get_callable_gan_model)
+
+  def test_regularization_infogan(self):
+    self._test_regularization_helper(get_infogan_model)
+
+  def test_regularization_callable_infogan(self):
+    self._test_regularization_helper(get_callable_infogan_model)
+
+  def test_regularization_acgan(self):
+    self._test_regularization_helper(get_acgan_model)
+
+  def test_regularization_callable_acgan(self):
+    self._test_regularization_helper(get_callable_acgan_model)
+
+  # Test that ACGAN models work.
+  def _test_acgan_helper(self, create_gan_model_fn):
+    model = create_gan_model_fn()
+    loss = train.gan_loss(model)
+    loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0)
+    loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
+    self.assertTrue(isinstance(loss, namedtuples.GANLoss))
+    self.assertTrue(isinstance(loss_ac_gen, namedtuples.GANLoss))
+    self.assertTrue(isinstance(loss_ac_dis, namedtuples.GANLoss))
+
+    # Check values.
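+    # For these fixed dummy models, the generator loss should be strictly
+    # smaller than the discriminator loss, and every AC-GAN loss variant
+    # should evaluate to a scalar.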
+    with self.test_session(use_gpu=True) as sess:
+      variables.global_variables_initializer().run()
+      loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run(
+          [loss.generator_loss,
+           loss_ac_gen.generator_loss,
+           loss_ac_dis.generator_loss])
+      loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run(
+          [loss.discriminator_loss,
+           loss_ac_gen.discriminator_loss,
+           loss_ac_dis.discriminator_loss])
+
+    self.assertTrue(loss_gen_np < loss_dis_np)
+    self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
+    self.assertTrue(np.isscalar(loss_ac_dis_gen_np))
+    self.assertTrue(np.isscalar(loss_ac_gen_dis_np))
+    self.assertTrue(np.isscalar(loss_ac_dis_dis_np))
+
+  def test_acgan(self):
+    self._test_acgan_helper(create_acgan_model)
+
+  def test_callable_acgan(self):
+    self._test_acgan_helper(create_callable_acgan_model)
+
+  def test_doesnt_crash_when_in_nested_scope(self):
+    with variable_scope.variable_scope('outer_scope'):
+      gan_model = train.gan_model(
+          generator_model,
+          discriminator_model,
+          real_data=array_ops.zeros([1, 2]),
+          generator_inputs=random_ops.random_normal([1, 2]))
+
+      # This should work inside a scope.
+      train.gan_loss(gan_model, gradient_penalty_weight=1.0)
+
+    # This should also work outside a scope.
+    train.gan_loss(gan_model, gradient_penalty_weight=1.0)
+
+
+class GANTrainOpsTest(test.TestCase):
+  """Tests for `gan_train_ops`."""
+
+  def _test_output_type_helper(self, create_gan_model_fn):
+    model = create_gan_model_fn()
+    loss = train.gan_loss(model)
+
+    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
+    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
+    train_ops = train.gan_train_ops(
+        model,
+        loss,
+        g_opt,
+        d_opt,
+        summarize_gradients=True,
+        colocate_gradients_with_ops=True)
+
+    self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
+
+  def test_output_type_gan(self):
+    self._test_output_type_helper(create_gan_model)
+
+  def test_output_type_callable_gan(self):
+    self._test_output_type_helper(create_callable_gan_model)
+
+  def test_output_type_infogan(self):
+    self._test_output_type_helper(create_infogan_model)
+
+  def test_output_type_callable_infogan(self):
+    self._test_output_type_helper(create_callable_infogan_model)
+
+  def test_output_type_acgan(self):
+    self._test_output_type_helper(create_acgan_model)
+
+  def test_output_type_callable_acgan(self):
+    self._test_output_type_helper(create_callable_acgan_model)
+
+  # TODO(joelshor): Add a test to check that custom update op is run.
+  def _test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
+    model = create_gan_model_fn()
+    loss = train.gan_loss(model)
+
+    # Add generator and discriminator update ops.
+    with variable_scope.variable_scope(model.generator_scope):
+      gen_update_count = variable_scope.get_variable('gen_count', initializer=0)
+      gen_update_op = gen_update_count.assign_add(1)
+      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
+    with variable_scope.variable_scope(model.discriminator_scope):
+      dis_update_count = variable_scope.get_variable('dis_count', initializer=0)
+      dis_update_op = dis_update_count.assign_add(1)
+      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)
+
+    # Add an update op outside the generator and discriminator scopes.
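+    # Depending on the test mode, the unused op is either passed explicitly
+    # via the `update_ops` kwarg or picked up from the UPDATE_OPS collection.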
+    if provide_update_ops:
+      kwargs = {'update_ops':
+                [constant_op.constant(1.0), gen_update_op, dis_update_op]}
+    else:
+      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0))
+      kwargs = {}
+
+    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
+    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
+
+    with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
+      train.gan_train_ops(model, loss, g_opt, d_opt,
+                          check_for_unused_update_ops=True, **kwargs)
+    train_ops = train.gan_train_ops(
+        model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)
+
+    with self.test_session(use_gpu=True) as sess:
+      sess.run(variables.global_variables_initializer())
+      self.assertEqual(0, gen_update_count.eval())
+      self.assertEqual(0, dis_update_count.eval())
+
+      train_ops.generator_train_op.eval()
+      self.assertEqual(1, gen_update_count.eval())
+      self.assertEqual(0, dis_update_count.eval())
+
+      train_ops.discriminator_train_op.eval()
+      self.assertEqual(1, gen_update_count.eval())
+      self.assertEqual(1, dis_update_count.eval())
+
+  def test_unused_update_ops_gan(self):
+    self._test_unused_update_ops(create_gan_model, False)
+
+  def test_unused_update_ops_gan_provideupdates(self):
+    self._test_unused_update_ops(create_gan_model, True)
+
+  def test_unused_update_ops_callable_gan(self):
+    self._test_unused_update_ops(create_callable_gan_model, False)
+
+  def test_unused_update_ops_callable_gan_provideupdates(self):
+    self._test_unused_update_ops(create_callable_gan_model, True)
+
+  def test_unused_update_ops_infogan(self):
+    self._test_unused_update_ops(create_infogan_model, False)
+
+  def test_unused_update_ops_infogan_provideupdates(self):
+    self._test_unused_update_ops(create_infogan_model, True)
+
+  def test_unused_update_ops_callable_infogan(self):
+    self._test_unused_update_ops(create_callable_infogan_model, False)
+
+  def test_unused_update_ops_callable_infogan_provideupdates(self):
+    self._test_unused_update_ops(create_callable_infogan_model, True)
+
+  def test_unused_update_ops_acgan(self):
+    self._test_unused_update_ops(create_acgan_model, False)
+
+  def test_unused_update_ops_acgan_provideupdates(self):
+    self._test_unused_update_ops(create_acgan_model, True)
+
+  def test_unused_update_ops_callable_acgan(self):
+    self._test_unused_update_ops(create_callable_acgan_model, False)
+
+  def test_unused_update_ops_callable_acgan_provideupdates(self):
+    self._test_unused_update_ops(create_callable_acgan_model, True)
+
+  def _test_sync_replicas_helper(self, create_gan_model_fn):
+    model = create_gan_model_fn()
+    loss = train.gan_loss(model)
+    num_trainable_vars = len(variables_lib.get_trainable_variables())
+
+    g_opt = get_sync_optimizer()
+    d_opt = get_sync_optimizer()
+    train_ops = train.gan_train_ops(
+        model,
+        loss,
+        generator_optimizer=g_opt,
+        discriminator_optimizer=d_opt)
+    self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
+    # No new trainable variables should have been added.
+    self.assertEqual(num_trainable_vars,
+                     len(variables_lib.get_trainable_variables()))
+
+    g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
+    d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)
+
+    # Check that update op is run properly.
+    global_step = training_util.get_or_create_global_step()
+    with self.test_session(use_gpu=True) as sess:
+      variables.global_variables_initializer().run()
+      variables.local_variables_initializer().run()
+
+      g_opt.chief_init_op.run()
+      d_opt.chief_init_op.run()
+
+      gstep_before = global_step.eval()
+
+      # Start required queue runner for SyncReplicasOptimizer.
+      coord = coordinator.Coordinator()
+      g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
+      d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)
+
+      g_sync_init_op.run()
+      d_sync_init_op.run()
+
+      train_ops.generator_train_op.eval()
+      # Check that global step wasn't incremented.
+      self.assertEqual(gstep_before, global_step.eval())
+
+      train_ops.discriminator_train_op.eval()
+      # Check that global step wasn't incremented.
+      self.assertEqual(gstep_before, global_step.eval())
+
+      coord.request_stop()
+      coord.join(g_threads + d_threads)
+
+  def test_sync_replicas_gan(self):
+    self._test_sync_replicas_helper(create_gan_model)
+
+  def test_sync_replicas_callable_gan(self):
+    self._test_sync_replicas_helper(create_callable_gan_model)
+
+  def test_sync_replicas_infogan(self):
+    self._test_sync_replicas_helper(create_infogan_model)
+
+  def test_sync_replicas_callable_infogan(self):
+    self._test_sync_replicas_helper(create_callable_infogan_model)
+
+  def test_sync_replicas_acgan(self):
+    self._test_sync_replicas_helper(create_acgan_model)
+
+  def test_sync_replicas_callable_acgan(self):
+    self._test_sync_replicas_helper(create_callable_acgan_model)
+
+
+class GANTrainTest(test.TestCase):
+  """Tests for `gan_train`."""
+
+  def _gan_train_ops(self, generator_add, discriminator_add):
+    step = training_util.create_global_step()
+    # Increment the global step every time a train op is run so we can count
+    # the number of times they're run.
+    # NOTE: `use_locking=True` is required to avoid race conditions with
+    # joint training.
+    train_ops = namedtuples.GANTrainOps(
+        generator_train_op=step.assign_add(generator_add, use_locking=True),
+        discriminator_train_op=step.assign_add(discriminator_add,
+                                               use_locking=True),
+        global_step_inc_op=step.assign_add(1))
+    return train_ops
+
+  def _test_run_helper(self, create_gan_model_fn):
+    random_seed.set_random_seed(1234)
+    model = create_gan_model_fn()
+    loss = train.gan_loss(model)
+
+    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
+    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
+    train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)
+
+    final_step = train.gan_train(
+        train_ops,
+        logdir='',
+        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
+    self.assertTrue(np.isscalar(final_step))
+    self.assertEqual(2, final_step)
+
+  def test_run_gan(self):
+    self._test_run_helper(create_gan_model)
+
+  def test_run_callable_gan(self):
+    self._test_run_helper(create_callable_gan_model)
+
+  def test_run_infogan(self):
+    self._test_run_helper(create_infogan_model)
+
+  def test_run_callable_infogan(self):
+    self._test_run_helper(create_callable_infogan_model)
+
+  def test_run_acgan(self):
+    self._test_run_helper(create_acgan_model)
+
+  def test_run_callable_acgan(self):
+    self._test_run_helper(create_callable_acgan_model)
+
+  # Test multiple train steps.
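+  # With `generator_add=10` and `discriminator_add=100`, each global step runs
+  # 3 generator steps, 4 discriminator steps, and one global-step increment,
+  # so the final step should land on 1 + 3 * 10 + 4 * 100 = 431.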
+  def _test_multiple_steps_helper(self, get_hooks_fn_fn):
+    train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
+    train_steps = namedtuples.GANTrainSteps(
+        generator_train_steps=3,
+        discriminator_train_steps=4)
+    final_step = train.gan_train(
+        train_ops,
+        get_hooks_fn=get_hooks_fn_fn(train_steps),
+        logdir='',
+        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])
+
+    self.assertTrue(np.isscalar(final_step))
+    self.assertEqual(1 + 3 * 10 + 4 * 100, final_step)
+
+  def test_multiple_steps_seq_train_steps(self):
+    self._test_multiple_steps_helper(train.get_sequential_train_hooks)
+
+  def test_multiple_steps_efficient_seq_train_steps(self):
+    self._test_multiple_steps_helper(train.get_joint_train_hooks)
+
+  def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
+    step = training_util.create_global_step()
+    train_ops = namedtuples.GANTrainOps(
+        generator_train_op=constant_op.constant(3.0),
+        discriminator_train_op=constant_op.constant(2.0),
+        global_step_inc_op=step.assign_add(1))
+    train_steps = namedtuples.GANTrainSteps(
+        generator_train_steps=3,
+        discriminator_train_steps=4)
+
+    final_loss = slim_learning.train(
+        train_op=train_ops,
+        logdir='',
+        global_step=step,
+        number_of_steps=1,
+        train_step_fn=train.get_sequential_train_steps(train_steps))
+    self.assertTrue(np.isscalar(final_loss))
+    self.assertEqual(17.0, final_loss)
+
+
+class PatchGANTest(test.TestCase):
+  """Tests that functions work on PatchGAN style output."""
+
+  def _test_patchgan_helper(self, create_gan_model_fn):
+    """Ensure that patch-based discriminators work end-to-end."""
+    random_seed.set_random_seed(1234)
+    model = create_gan_model_fn()
+    loss = train.gan_loss(model)
+
+    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
+    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
+    train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)
+
+    final_step = train.gan_train(
+        train_ops,
+        logdir='',
+        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
+    self.assertTrue(np.isscalar(final_step))
+    self.assertEqual(2, final_step)
+
+  def test_patchgan_gan(self):
+    self._test_patchgan_helper(create_gan_model)
+
+  def test_patchgan_callable_gan(self):
+    self._test_patchgan_helper(create_callable_gan_model)
+
+  def test_patchgan_infogan(self):
+    self._test_patchgan_helper(create_infogan_model)
+
+  def test_patchgan_callable_infogan(self):
+    self._test_patchgan_helper(create_callable_infogan_model)
+
+  def test_patchgan_acgan(self):
+    self._test_patchgan_helper(create_acgan_model)
+
+  def test_patchgan_callable_acgan(self):
+    self._test_patchgan_helper(create_callable_acgan_model)
+
+
+if __name__ == '__main__':
+  test.main()
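
For orientation, here is a minimal end-to-end sketch of the API surface these
tests exercise, written against the collapsed `tf.contrib.gan` namespace. The
`generator_fn` and `discriminator_fn` below are hypothetical placeholders (the
real test models are defined earlier in `train_test.py`), and the logdir is
invented; the remaining calls mirror the ones made in the tests above.

    import tensorflow as tf

    tfgan = tf.contrib.gan

    def generator_fn(noise):
      # Hypothetical one-layer generator mapping noise to 2-D "data".
      return tf.contrib.layers.fully_connected(noise, 2)

    def discriminator_fn(data, unused_generator_inputs):
      # Hypothetical one-layer discriminator producing a single logit.
      return tf.contrib.layers.fully_connected(data, 1)

    # Assemble the GANModel namedtuple, the losses, and the train ops.
    gan_model = tfgan.gan_model(
        generator_fn,
        discriminator_fn,
        real_data=tf.zeros([1, 2]),
        generator_inputs=tf.random_normal([1, 2]))
    gan_loss = tfgan.gan_loss(gan_model, gradient_penalty_weight=1.0)
    train_ops = tfgan.gan_train_ops(
        gan_model,
        gan_loss,
        generator_optimizer=tf.train.GradientDescentOptimizer(1.0),
        discriminator_optimizer=tf.train.GradientDescentOptimizer(1.0))

    # Run alternating generator/discriminator training for two global steps.
    tfgan.gan_train(
        train_ops,
        logdir='/tmp/tfgan_sketch',  # hypothetical log directory
        hooks=[tf.train.StopAtStepHook(num_steps=2)])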