Migrate core TFGAN functions to opensource.
PiperOrigin-RevId: 168391923
parent bc6b60f1bc
commit f63aa7f49f
tensorflow/contrib/gan/BUILD

@@ -16,6 +16,60 @@ py_library(
     deps = [
         ":features",
         ":losses",
+        ":namedtuples",
+        ":train",
     ],
 )
 
+py_library(
+    name = "namedtuples",
+    srcs = ["python/namedtuples.py"],
+    srcs_version = "PY2AND3",
+)
+
+py_library(
+    name = "train",
+    srcs = ["python/train.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":losses",
+        ":namedtuples",
+        "//tensorflow/contrib/framework:framework_py",
+        "//tensorflow/contrib/slim:learning",
+        "//tensorflow/contrib/training:training_py",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:check_ops",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:init_ops",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python/ops/distributions",
+        "//tensorflow/python/ops/losses",
+    ],
+)
+
+py_test(
+    name = "train_test",
+    srcs = ["python/train_test.py"],
+    srcs_version = "PY2AND3",
+    deps = [
+        ":namedtuples",
+        ":train",
+        "//tensorflow/contrib/framework:framework_py",
+        "//tensorflow/contrib/slim:learning",
+        "//tensorflow/python:array_ops",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:constant_op",
+        "//tensorflow/python:dtypes",
+        "//tensorflow/python:framework_ops",
+        "//tensorflow/python:random_ops",
+        "//tensorflow/python:random_seed",
+        "//tensorflow/python:training",
+        "//tensorflow/python:variable_scope",
+        "//tensorflow/python:variables",
+        "//tensorflow/python/ops/distributions",
+        "//third_party/py/numpy",
+    ],
+)
+
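The new targets can be consumed like any other py_library. As an illustrative sketch only (hypothetical target and file names, visibility permitting), an in-tree dependent rule would look like:

py_binary(
    name = "gan_trainer_example",  # hypothetical
    srcs = ["gan_trainer_example.py"],
    deps = [
        "//tensorflow/contrib/gan:namedtuples",
        "//tensorflow/contrib/gan:train",
    ],
)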
tensorflow/contrib/gan/__init__.py

@@ -1,4 +1,4 @@
-# Copyright 2016 Google Inc. All Rights Reserved.
+# Copyright 2017 Google Inc. All Rights Reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -21,7 +21,20 @@ from __future__ import print_function
 # Collapse TFGAN into a tiered namespace.
 from tensorflow.contrib.gan.python import features
 from tensorflow.contrib.gan.python import losses
+from tensorflow.contrib.gan.python import namedtuples
+from tensorflow.contrib.gan.python import train
 
-del absolute_import
-del division
-del print_function
+# pylint: disable=unused-import,wildcard-import
+from tensorflow.contrib.gan.python.namedtuples import *
+from tensorflow.contrib.gan.python.train import *
+# pylint: enable=unused-import,wildcard-import
+
+from tensorflow.python.util.all_util import remove_undocumented
+
+_allowed_symbols = [
+    'features',
+    'losses',
+]
+_allowed_symbols += train.__all__
+_allowed_symbols += namedtuples.__all__
+remove_undocumented(__name__, _allowed_symbols)
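The net effect of this hunk is that everything in `train.__all__` and `namedtuples.__all__` is flattened onto the package, while `features` and `losses` stay tiered. As a sketch of what that makes available (assuming a TensorFlow build that includes this contrib package):

import tensorflow as tf

tfgan = tf.contrib.gan
tfgan.gan_model       # flattened from train.py via the wildcard import
tfgan.GANTrainSteps   # flattened from namedtuples.py
tfgan.losses          # still tiered under the losses submodule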
tensorflow/contrib/gan/python/namedtuples.py (new file, 149 lines)
@@ -0,0 +1,149 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Named tuples for TFGAN."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections


__all__ = [
    'GANModel',
    'InfoGANModel',
    'ACGANModel',
    'GANLoss',
    'GANTrainOps',
    'GANTrainSteps',
]


class GANModel(
    collections.namedtuple('GANModel', (
        'generator_inputs',
        'generated_data',
        'generator_variables',
        'generator_scope',
        'generator_fn',
        'real_data',
        'discriminator_real_outputs',
        'discriminator_gen_outputs',
        'discriminator_variables',
        'discriminator_scope',
        'discriminator_fn',
    ))):
  """A GANModel contains all the pieces needed for GAN training.

  Generative Adversarial Networks (https://arxiv.org/abs/1406.2661) attempt
  to create an implicit generative model of data by solving a two agent game.
  The generator generates candidate examples that are supposed to match the
  data distribution, and the discriminator aims to tell the real examples
  apart from the generated samples.

  Args:
    generator_inputs: The random noise source that acts as input to the
      generator.
    generated_data: The generated output data of the GAN.
    generator_variables: A list of all generator variables.
    generator_scope: Variable scope all generator variables live in.
    generator_fn: The generator function.
    real_data: A tensor of real data.
    discriminator_real_outputs: The discriminator's output on real data.
    discriminator_gen_outputs: The discriminator's output on generated data.
    discriminator_variables: A list of all discriminator variables.
    discriminator_scope: Variable scope all discriminator variables live in.
    discriminator_fn: The discriminator function.
  """


# TODO(joelshor): Have this class inherit from `GANModel`.
class InfoGANModel(
    collections.namedtuple('InfoGANModel', GANModel._fields + (
        'structured_generator_inputs',
        'predicted_distributions',
    ))):
  """An InfoGANModel contains all the pieces needed for InfoGAN training.

  See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    structured_generator_inputs: A list of Tensors representing the random
      noise that must have high mutual information with the generator output.
      List length should match `predicted_distributions`.
    predicted_distributions: A list of tf.Distributions. Predicted by the
      recognizer, and used to evaluate the likelihood of the structured noise.
      List length should match `structured_generator_inputs`.
  """


class ACGANModel(
    collections.namedtuple('ACGANModel', GANModel._fields +
                           ('one_hot_labels',
                            'discriminator_real_classification_logits',
                            'discriminator_gen_classification_logits',))):
  """An ACGANModel contains all the pieces needed for ACGAN training.

  See https://arxiv.org/abs/1610.09585 for more details.

  Args:
    one_hot_labels: A Tensor holding one-hot labels for the batch.
    discriminator_real_classification_logits: Classification logits for real
      data.
    discriminator_gen_classification_logits: Classification logits for
      generated data.
  """


class GANLoss(
    collections.namedtuple('GANLoss', (
        'generator_loss',
        'discriminator_loss'
    ))):
  """GANLoss contains the generator and discriminator losses.

  Args:
    generator_loss: A tensor for the generator loss.
    discriminator_loss: A tensor for the discriminator loss.
  """


class GANTrainOps(
    collections.namedtuple('GANTrainOps', (
        'generator_train_op',
        'discriminator_train_op',
        'global_step_inc_op'
    ))):
  """GANTrainOps contains the training ops.

  Args:
    generator_train_op: Op that performs a generator update step.
    discriminator_train_op: Op that performs a discriminator update step.
    global_step_inc_op: Op that increments the shared global step.
  """


class GANTrainSteps(
    collections.namedtuple('GANTrainSteps', (
        'generator_train_steps',
        'discriminator_train_steps'
    ))):
  """Contains configuration for the GAN Training.

  Args:
    generator_train_steps: Number of generator steps to take in each GAN step.
    discriminator_train_steps: Number of discriminator steps to take in each
      GAN step.
  """
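Because `InfoGANModel` and `ACGANModel` extend `GANModel` by concatenating `_fields` rather than subclassing (see the TODO above), the relationship is structural, not nominal. A small stand-alone sketch of what that implies, with stand-in names mirroring the pattern:

import collections

Base = collections.namedtuple('Base', ('a', 'b'))
Derived = collections.namedtuple('Derived', Base._fields + ('c',))

d = Derived(a=1, b=2, c=3)
print(d._fields)            # ('a', 'b', 'c')
print(isinstance(d, Base))  # False: field reuse, not inheritance
print(d._replace(c=4))      # namedtuples are immutable; _replace copies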
tensorflow/contrib/gan/python/train.py (new file, 804 lines)
@@ -0,0 +1,804 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TFGAN project provides a lightweight GAN training/testing framework.

See examples in `tensorflow_models` for details on how to use.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import losses as tfgan_losses
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.contrib.training.python.training import training
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.distributions import distribution as ds
from tensorflow.python.ops.losses import losses
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util


__all__ = [
    'gan_model',
    'infogan_model',
    'acgan_model',
    'gan_loss',
    'gan_train_ops',
    'gan_train',
    'get_sequential_train_hooks',
    'get_joint_train_hooks',
    'get_sequential_train_steps',
]


def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
  """Convert input, list of inputs, or dictionary of inputs to Tensors."""
  if isinstance(tensor_or_l_or_d, (list, tuple)):
    return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
  elif isinstance(tensor_or_l_or_d, dict):
    return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
  else:
    return ops.convert_to_tensor(tensor_or_l_or_d)


def gan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns GAN model outputs and variables.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the
      conditional GAN case, this might be the generator's conditioning.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator produces Tensors that
      are the same shape as the real data. Otherwise, skip this check.

  Returns:
    A GANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
  """
  # Create models.
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    discriminator_gen_outputs = discriminator_fn(generated_data,
                                                 generator_inputs)
  with variable_scope.variable_scope(dis_scope, reuse=True):
    real_data = ops.convert_to_tensor(real_data)
    discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(dis_scope)

  return namedtuples.GANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data,
      discriminator_real_outputs,
      discriminator_gen_outputs,
      discriminator_variables,
      dis_scope,
      discriminator_fn)
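
As a usage sketch (not part of this commit): `tfgan` below stands for the `tf.contrib.gan` namespace assembled in `__init__.py` above, and the two model functions are hypothetical single-layer toys, but the call signature is exactly the one defined here:

import tensorflow as tf

tfgan = tf.contrib.gan


def toy_generator(noise):
  # Map noise to the shape of `real_data`.
  return tf.layers.dense(noise, 2)


def toy_discriminator(data, unused_conditioning):
  # A single unbounded logit, as required of `discriminator_fn`.
  return tf.layers.dense(data, 1)


model = tfgan.gan_model(
    generator_fn=toy_generator,
    discriminator_fn=toy_discriminator,
    real_data=tf.zeros([16, 2]),
    generator_inputs=tf.random_normal([16, 4]))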


def _validate_distributions(distributions_l, noise_l):
  if not isinstance(distributions_l, (tuple, list)):
    raise ValueError('`predicted_distributions` must be a list. Instead, '
                     'found %s.' % type(distributions_l))
  for dist in distributions_l:
    if not isinstance(dist, ds.Distribution):
      raise ValueError('Every element in `predicted_distributions` must be a '
                       '`tf.Distribution`. Instead, found %s.' % type(dist))
  if len(distributions_l) != len(noise_l):
    raise ValueError('Length of `predicted_distributions` %i must be the same '
                     'as the length of structured noise %i.' %
                     (len(distributions_l), len(noise_l)))


def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator'):
  """Returns InfoGAN model outputs and variables.

  See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    generator_fn: A python lambda that takes a list of Tensors as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list).
      `logits` are in the range [-inf, inf], and `distribution_list` is a list
      of TensorFlow distributions representing the predicted noise
      distribution of the ith structured noise.
    real_data: A Tensor representing the real data.
    unstructured_generator_inputs: A list of Tensors to the generator.
      These tensors represent the unstructured noise or conditioning.
    structured_generator_inputs: A list of Tensors to the generator.
      These tensors must have high mutual information with the recognizer.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    An InfoGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    ValueError: If the discriminator output is malformed.
  """
  # Create models.
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    unstructured_generator_inputs = _convert_tensor_or_l_or_d(
        unstructured_generator_inputs)
    structured_generator_inputs = _convert_tensor_or_l_or_d(
        structured_generator_inputs)
    generator_inputs = (
        unstructured_generator_inputs + structured_generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as disc_scope:
    dis_gen_outputs, predicted_distributions = discriminator_fn(
        generated_data, generator_inputs)
  _validate_distributions(predicted_distributions, structured_generator_inputs)
  with variable_scope.variable_scope(disc_scope, reuse=True):
    real_data = ops.convert_to_tensor(real_data)
    dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

  if not generated_data.get_shape().is_compatible_with(real_data.get_shape()):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (generated_data.get_shape(), real_data.get_shape()))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(
      disc_scope)

  return namedtuples.InfoGANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data,
      dis_real_outputs,
      dis_gen_outputs,
      discriminator_variables,
      disc_scope,
      lambda x, y: discriminator_fn(x, y)[0],  # conform to non-InfoGAN API
      structured_generator_inputs,
      predicted_distributions)
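
A sketch of a `discriminator_fn` satisfying the (logits, distribution_list) contract validated above. Layer sizes are hypothetical; it assumes `tf.distributions.Categorical` is exported at this TF version (the core implementation under `tensorflow/python/ops/distributions` is what this commit itself depends on):

import tensorflow as tf


def infogan_discriminator(data, unused_conditioning, num_categories=10):
  hidden = tf.layers.dense(data, 64, activation=tf.nn.relu)
  logits = tf.layers.dense(hidden, 1)
  # The recognizer head predicts a distribution over each structured code.
  cat_logits = tf.layers.dense(hidden, num_categories)
  return logits, [tf.distributions.Categorical(logits=cat_logits)]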


def _validate_acgan_discriminator_outputs(discriminator_output):
  try:
    a, b = discriminator_output
  except (TypeError, ValueError):
    raise TypeError(
        'A discriminator function for ACGAN must output a tuple '
        'consisting of (discrimination logits, classification logits).')
  return a, b


def acgan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    one_hot_labels,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    check_shapes=True):
  """Returns an ACGANModel, with all the pieces needed for ACGAN training.

  `acgan_model` is the same as `gan_model`, except that the discriminator
  additionally outputs logits classifying the input (whether real or
  generated). Therefore an explicit field holding `one_hot_labels` is
  necessary, as well as a `discriminator_fn` that outputs a 2-tuple holding
  the real/fake logits and the classification logits.

  See https://arxiv.org/abs/1610.09585 for more details.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a tuple consisting of two Tensors:
      (1) real/fake logits in the range [-inf, inf]
      (2) classification logits in the range [-inf, inf]
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the
      conditional GAN case, this might be the generator's conditioning.
    one_hot_labels: A Tensor holding one-hot labels for the batch. Needed by
      acgan_loss.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator produces Tensors that
      are the same shape as the real data. Otherwise, skip this check.

  Returns:
    An ACGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    TypeError: If the discriminator does not output a tuple consisting of
      (discrimination logits, classification logits).
  """
  # Create models.
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    (discriminator_gen_outputs, discriminator_gen_classification_logits
    ) = _validate_acgan_discriminator_outputs(
        discriminator_fn(generated_data, generator_inputs))
  with variable_scope.variable_scope(dis_scope, reuse=True):
    real_data = ops.convert_to_tensor(real_data)
    (discriminator_real_outputs, discriminator_real_classification_logits
    ) = _validate_acgan_discriminator_outputs(
        discriminator_fn(real_data, generator_inputs))
  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(
      dis_scope)

  return namedtuples.ACGANModel(
      generator_inputs, generated_data, generator_variables, gen_scope,
      generator_fn, real_data, discriminator_real_outputs,
      discriminator_gen_outputs, discriminator_variables, dis_scope,
      discriminator_fn, one_hot_labels,
      discriminator_real_classification_logits,
      discriminator_gen_classification_logits)
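
A minimal sketch of an ACGAN `discriminator_fn` satisfying the 2-tuple contract validated above (layer sizes hypothetical):

import tensorflow as tf


def acgan_discriminator(data, unused_conditioning, num_classes=10):
  hidden = tf.layers.dense(data, 64, activation=tf.nn.relu)
  realness_logits = tf.layers.dense(hidden, 1)          # real/fake head
  class_logits = tf.layers.dense(hidden, num_classes)   # classification head
  return realness_logits, class_logits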


def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
  if isinstance(aux_loss_weight, ops.Tensor):
    aux_loss_weight.shape.assert_is_compatible_with([])
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(aux_loss_weight, 0.0)]):
      aux_loss_weight = array_ops.identity(aux_loss_weight)
  elif aux_loss_weight is not None and aux_loss_weight < 0:
    raise ValueError('`%s` must be non-negative. Instead, was %s' %
                     (name, aux_loss_weight))
  return aux_loss_weight


def _use_aux_loss(aux_loss_weight):
  if aux_loss_weight is not None:
    if not isinstance(aux_loss_weight, ops.Tensor):
      return aux_loss_weight > 0
    else:
      return True
  else:
    return False


def gan_loss(
    # GANModel.
    model,
    # Loss functions.
    generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
    # Auxiliary losses.
    gradient_penalty_weight=None,
    gradient_penalty_epsilon=1e-10,
    mutual_information_penalty_weight=None,
    aux_cond_generator_weight=None,
    aux_cond_discriminator_weight=None,
    # Options.
    add_summaries=True):
  """Returns losses necessary to train generator and discriminator.

  Args:
    model: A GANModel tuple.
    generator_loss_fn: The loss function on the generator. Takes a GANModel
      tuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      GANModel tuple.
    gradient_penalty_weight: If not `None`, must be a non-negative Python
      number or Tensor indicating how much to weight the gradient penalty.
      See https://arxiv.org/pdf/1704.00028.pdf for more details.
    gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
      small positive value used by the gradient penalty function for numerical
      stability. Note some applications will need to increase this value to
      avoid NaNs.
    mutual_information_penalty_weight: If not `None`, must be a non-negative
      Python number or Tensor indicating how much to weight the mutual
      information penalty. See https://arxiv.org/abs/1606.03657 for more
      details.
    aux_cond_generator_weight: If not None: add a classification loss as in
      https://arxiv.org/abs/1610.09585
    aux_cond_discriminator_weight: If not None: add a classification loss as
      in https://arxiv.org/abs/1610.09585
    add_summaries: Whether or not to add summaries for the losses.

  Returns:
    A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
    regularization losses.

  Raises:
    ValueError: If any of the auxiliary loss weights is provided and negative.
    ValueError: If `mutual_information_penalty_weight` is provided, but the
      `model` isn't an `InfoGANModel`.
  """
  # Validate arguments.
  gradient_penalty_weight = _validate_aux_loss_weight(gradient_penalty_weight,
                                                      'gradient_penalty_weight')
  mutual_information_penalty_weight = _validate_aux_loss_weight(
      mutual_information_penalty_weight, 'infogan_weight')
  aux_cond_generator_weight = _validate_aux_loss_weight(
      aux_cond_generator_weight, 'aux_cond_generator_weight')
  aux_cond_discriminator_weight = _validate_aux_loss_weight(
      aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')

  # Verify configuration for mutual information penalty.
  if (_use_aux_loss(mutual_information_penalty_weight) and
      not isinstance(model, namedtuples.InfoGANModel)):
    raise ValueError(
        'When `mutual_information_penalty_weight` is provided, `model` must '
        'be an `InfoGANModel`. Instead, was %s.' % type(model))

  # Verify configuration for auxiliary condition loss (ACGAN).
  if ((_use_aux_loss(aux_cond_generator_weight) or
       _use_aux_loss(aux_cond_discriminator_weight)) and
      not isinstance(model, namedtuples.ACGANModel)):
    raise ValueError(
        'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
        'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
        type(model))

  # Create standard losses.
  gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
  dis_loss = discriminator_loss_fn(model, add_summaries=add_summaries)

  # Add optional extra losses.
  if _use_aux_loss(gradient_penalty_weight):
    gp_loss = tfgan_losses.wasserstein_gradient_penalty(
        model, epsilon=gradient_penalty_epsilon, add_summaries=add_summaries)
    dis_loss += gradient_penalty_weight * gp_loss
  if _use_aux_loss(mutual_information_penalty_weight):
    info_loss = tfgan_losses.mutual_information_penalty(
        model, add_summaries=add_summaries)
    dis_loss += mutual_information_penalty_weight * info_loss
    gen_loss += mutual_information_penalty_weight * info_loss
  if _use_aux_loss(aux_cond_generator_weight):
    ac_gen_loss = tfgan_losses.acgan_generator_loss(
        model, add_summaries=add_summaries)
    gen_loss += aux_cond_generator_weight * ac_gen_loss
  if _use_aux_loss(aux_cond_discriminator_weight):
    ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
        model, add_summaries=add_summaries)
    dis_loss += aux_cond_discriminator_weight * ac_disc_loss
  # Gather auxiliary (regularization) losses.
  if model.generator_scope:
    gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
  else:
    gen_reg_loss = 0
  if model.discriminator_scope:
    dis_reg_loss = losses.get_regularization_loss(
        model.discriminator_scope.name)
  else:
    dis_reg_loss = 0

  return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)
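
Continuing the sketch begun after `gan_model`: the default Wasserstein losses plus a gradient penalty give a WGAN-GP-style objective. The weight below is illustrative (10.0 is the value used in the paper linked above):

loss = tfgan.gan_loss(
    model,
    gradient_penalty_weight=10.0,
    add_summaries=True)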


def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
  """Gets generator and discriminator update ops.

  Args:
    kwargs: A dictionary of kwargs to be passed to `create_train_op`.
      `update_ops` is removed, if present.
    gen_scope: A scope for the generator.
    dis_scope: A scope for the discriminator.
    check_for_unused_ops: A Python bool. If `True`, throw Exception if there
      are unused update ops.

  Returns:
    A 2-tuple of (generator update ops, discriminator update ops).

  Raises:
    ValueError: If there are update ops outside of the generator or
      discriminator scopes.
  """
  if 'update_ops' in kwargs:
    update_ops = set(kwargs['update_ops'])
    del kwargs['update_ops']
  else:
    update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))

  all_gen_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, gen_scope))
  all_dis_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, dis_scope))

  if check_for_unused_ops:
    unused_ops = update_ops - all_gen_ops - all_dis_ops
    if unused_ops:
      raise ValueError('There are unused update ops: %s' % unused_ops)

  gen_update_ops = list(all_gen_ops & update_ops)
  dis_update_ops = list(all_dis_ops & update_ops)

  return gen_update_ops, dis_update_ops


def gan_train_ops(
    model,  # GANModel
    loss,  # GANLoss
    generator_optimizer,
    discriminator_optimizer,
    # Optional check flags.
    check_for_unused_update_ops=True,
    # Optional args to pass directly to `create_train_op`.
    **kwargs):
  """Returns GAN train ops.

  The highest-level call in TFGAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    model: A GANModel.
    loss: A GANLoss.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    **kwargs: Keyword args to pass directly to
      `training.create_train_op` for both the generator and
      discriminator train op.

  Returns:
    A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that
    can be used to train a generator/discriminator pair.
  """
  # Create global step increment op.
  global_step = training_util.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  # Get generator and discriminator update ops. We split them so that update
  # ops aren't accidentally run multiple times. For now, throw an error if
  # there are update ops that aren't associated with either the generator or
  # the discriminator. Might modify the `kwargs` dictionary.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  generator_global_step = None
  if isinstance(generator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # TODO(joelshor): Figure out a way to get this to work without including
    # the dummy global step in the checkpoint.
    # WARNING: Making this variable a local variable causes sync replicas to
    # hang forever.
    generator_global_step = variable_scope.get_variable(
        'dummy_global_step_generator',
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    gen_update_ops += [generator_global_step.assign(global_step)]
  with ops.name_scope('generator_train'):
    gen_train_op = training.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  discriminator_global_step = None
  if isinstance(discriminator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # See comment above `generator_global_step`.
    discriminator_global_step = variable_scope.get_variable(
        'dummy_global_step_discriminator',
        shape=[],
        dtype=dtypes.int64,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    dis_update_ops += [discriminator_global_step.assign(global_step)]
  with ops.name_scope('discriminator_train'):
    disc_train_op = training.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
        global_step=discriminator_global_step,
        update_ops=dis_update_ops,
        **kwargs)

  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc)
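
Continuing the same sketch, optimizers are supplied by the caller. Adam with a low beta1 is a common GAN choice but purely illustrative here:

train_ops = tfgan.gan_train_ops(
    model,
    loss,
    generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
    discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5))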


# TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time
# Adaptive Image Compression` (https://arxiv.org/abs/1705.05823).
class RunTrainOpsHook(session_run_hook.SessionRunHook):
  """A hook to run train ops a fixed number of times."""

  def __init__(self, train_ops, train_steps):
    """Run train ops a certain number of times.

    Args:
      train_ops: A train op or iterable of train ops to run.
      train_steps: The number of times to run the op(s).
    """
    if not isinstance(train_ops, (list, tuple)):
      train_ops = [train_ops]
    self._train_ops = train_ops
    self._train_steps = train_steps

  def before_run(self, run_context):
    for _ in range(self._train_steps):
      run_context.session.run(self._train_ops)


def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for sequential GAN training.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
  """
  def get_hooks(train_ops):
    generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
                                     train_steps.generator_train_steps)
    discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
                                         train_steps.discriminator_train_steps)
    return [generator_hook, discriminator_hook]
  return get_hooks


def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for joint GAN training.

  When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON
  ALL OPTIMIZERS TO AVOID RACE CONDITIONS.

  The order of steps taken is:
  1) Combined generator and discriminator steps
  2) Generator only steps, if any remain
  3) Discriminator only steps, if any remain

  **NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates
  for the generator and discriminator simultaneously whenever possible. This
  reduces the number of `tf.Session` calls, and can also change the training
  semantics.

  To illustrate the difference, look at the following example:

  `train_steps=namedtuples.GANTrainSteps(3, 5)` will cause
  `get_sequential_train_hooks` to make 8 session calls:
  1) 3 generator steps
  2) 5 discriminator steps

  In contrast, `get_joint_train_hooks` will make 5 session calls:
  1) 3 generator + discriminator steps
  2) 2 discriminator steps

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
  """
  g_steps = train_steps.generator_train_steps
  d_steps = train_steps.discriminator_train_steps
  # Get the number of each type of step that should be run.
  num_d_and_g_steps = min(g_steps, d_steps)
  num_g_steps = g_steps - num_d_and_g_steps
  num_d_steps = d_steps - num_d_and_g_steps

  def get_hooks(train_ops):
    g_op = train_ops.generator_train_op
    d_op = train_ops.discriminator_train_op

    joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps)
    g_hook = RunTrainOpsHook(g_op, num_g_steps)
    d_hook = RunTrainOpsHook(d_op, num_d_steps)

    return [joint_hook, g_hook, d_hook]
  return get_hooks
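
A sketch of requesting an unbalanced 1:5 generator:discriminator schedule with the joint hooks (the ratio is illustrative):

hooks_fn = tfgan.get_joint_train_hooks(
    tfgan.GANTrainSteps(generator_train_steps=1,
                        discriminator_train_steps=5))
# Per global step: 1 joint (G+D) session call, then 4 D-only calls.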


# TODO(joelshor): This function currently returns the global step. Find a
# good way for it to return the generator, discriminator, and final losses.
def gan_train(
    train_ops,
    logdir,
    get_hooks_fn=get_sequential_train_hooks(),
    master='',
    is_chief=True,
    scaffold=None,
    hooks=None,
    chief_only_hooks=None,
    save_checkpoint_secs=600,
    save_summaries_steps=100,
    config=None):
  """A wrapper around `contrib.training.train` that uses GAN hooks.

  Args:
    train_ops: A GANTrainOps named tuple.
    logdir: The directory where the graph and checkpoints are saved.
    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a
      list of hooks.
    master: The URL of the master.
    is_chief: Specifies whether or not the training is being run by the
      primary replica during replica training.
    scaffold: A tf.train.Scaffold instance.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside
      the training loop.
    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are
      run inside the training loop for the chief trainer only.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is
      saved using a default checkpoint saver. If `save_checkpoint_secs` is
      set to `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If
      `save_summaries_steps` is set to `None`, then the default summary saver
      isn't used.
    config: An instance of `tf.ConfigProto`.

  Returns:
    Output of the call to `training.train`.
  """
  new_hooks = get_hooks_fn(train_ops)
  if hooks is not None:
    hooks = list(hooks) + list(new_hooks)
  else:
    hooks = new_hooks
  return training.train(
      train_ops.global_step_inc_op,
      logdir,
      master=master,
      is_chief=is_chief,
      scaffold=scaffold,
      hooks=hooks,
      chief_only_hooks=chief_only_hooks,
      save_checkpoint_secs=save_checkpoint_secs,
      save_summaries_steps=save_summaries_steps,
      config=config)
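
Completing the sketch begun after `gan_model`: with the default sequential hooks, a single call drives the whole loop. The log directory and stopping step below are arbitrary:

tfgan.gan_train(
    train_ops,
    logdir='/tmp/tfgan_example',  # hypothetical log directory
    hooks=[tf.train.StopAtStepHook(num_steps=1000)])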


def get_sequential_train_steps(
    train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a thin wrapper around slim.learning.train_step, for GANs.

  This function is to provide support for the Supervisor. For new code, please
  use `MonitoredSession` and `get_sequential_train_hooks`.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that can be used for `train_step_fn` for GANs.
  """

  def sequential_train_steps(sess, train_ops, global_step, train_step_kwargs):
    """A thin wrapper around slim.learning.train_step, for GANs.

    Args:
      sess: A TensorFlow session.
      train_ops: A GANTrainOps tuple of train ops to run.
      global_step: The global step.
      train_step_kwargs: Dictionary controlling `train_step` behavior.

    Returns:
      A scalar final loss and a bool whether or not the train loop should
      stop.
    """
    # Only run `should_stop` at the end, if required. Make a local copy of
    # `train_step_kwargs`, if necessary, so as not to modify the caller's
    # dictionary.
    should_stop_op, train_kwargs = None, train_step_kwargs
    if 'should_stop' in train_step_kwargs:
      should_stop_op = train_step_kwargs['should_stop']
      train_kwargs = train_step_kwargs.copy()
      del train_kwargs['should_stop']

    # Run generator training steps.
    gen_loss = 0
    for _ in range(train_steps.generator_train_steps):
      cur_gen_loss, _ = slim_learning.train_step(
          sess, train_ops.generator_train_op, global_step, train_kwargs)
      gen_loss += cur_gen_loss

    # Run discriminator training steps.
    dis_loss = 0
    for _ in range(train_steps.discriminator_train_steps):
      cur_dis_loss, _ = slim_learning.train_step(
          sess, train_ops.discriminator_train_op, global_step, train_kwargs)
      dis_loss += cur_dis_loss

    sess.run(train_ops.global_step_inc_op)

    # Run the `should_stop` op after the global step has been incremented, so
    # that the `should_stop` aligns with the proper `global_step` count.
    if should_stop_op is not None:
      should_stop = sess.run(should_stop_op)
    else:
      should_stop = False

    return gen_loss + dis_loss, should_stop

  return sequential_train_steps
tensorflow/contrib/gan/python/train_test.py (new file, 745 lines)
@ -0,0 +1,745 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for gan.python.train."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import numpy as np
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import variables as variables_lib
|
||||
from tensorflow.contrib.gan.python import namedtuples
|
||||
from tensorflow.contrib.gan.python import train
|
||||
from tensorflow.contrib.slim.python.slim import learning as slim_learning
|
||||
from tensorflow.python.framework import constant_op
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import random_seed
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.distributions import categorical
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import basic_session_run_hooks
|
||||
from tensorflow.python.training import coordinator
|
||||
from tensorflow.python.training import gradient_descent
|
||||
from tensorflow.python.training import sync_replicas_optimizer
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
def generator_model(inputs):
|
||||
return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs
|
||||
|
||||
|
||||
class Generator(object):
|
||||
|
||||
def __call__(self, inputs):
|
||||
return generator_model(inputs)
|
||||
|
||||
|
||||
def infogan_generator_model(inputs):
|
||||
return variable_scope.get_variable('dummy_g', initializer=2.0) * inputs[0]
|
||||
|
||||
|
||||
class InfoGANGenerator(object):
|
||||
|
||||
def __call__(self, inputs):
|
||||
return infogan_generator_model(inputs)
|
||||
|
||||
|
||||
def discriminator_model(inputs, _):
|
||||
return variable_scope.get_variable('dummy_d', initializer=2.0) * inputs
|
||||
|
||||
|
||||
class Discriminator(object):
|
||||
|
||||
def __call__(self, inputs, _):
|
||||
return discriminator_model(inputs, _)
|
||||
|
||||
|
||||
def infogan_discriminator_model(inputs, _):
|
||||
return (variable_scope.get_variable('dummy_d', initializer=2.0) * inputs,
|
||||
[categorical.Categorical([1.0])])
|
||||
|
||||
|
||||
class InfoGANDiscriminator(object):
|
||||
|
||||
def __call__(self, inputs, _):
|
||||
return infogan_discriminator_model(inputs, _)
|
||||
|
||||
|
||||
def acgan_discriminator_model(inputs, _, num_classes=10):
|
||||
return (discriminator_model(inputs, _), array_ops.one_hot(
|
||||
# TODO(haeusser): infer batch size from input
|
||||
random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32),
|
||||
num_classes))
|
||||
|
||||
|
||||
class ACGANDiscriminator(object):
|
||||
|
||||
def __call__(self, inputs, _, num_classes=10):
|
||||
return (discriminator_model(inputs, _), array_ops.one_hot(
|
||||
# TODO(haeusser): infer batch size from input
|
||||
random_ops.random_uniform([3], maxval=num_classes, dtype=dtypes.int32),
|
||||
num_classes))
|
||||
|
||||
|
||||
def get_gan_model():
|
||||
# TODO(joelshor): Find a better way of creating a variable scope.
|
||||
with variable_scope.variable_scope('generator') as gen_scope:
|
||||
pass
|
||||
with variable_scope.variable_scope('discriminator') as dis_scope:
|
||||
pass
|
||||
return namedtuples.GANModel(
|
||||
generator_inputs=None,
|
||||
generated_data=None,
|
||||
generator_variables=None,
|
||||
generator_scope=gen_scope,
|
||||
generator_fn=generator_model,
|
||||
real_data=array_ops.ones([1, 2, 3]),
|
||||
discriminator_real_outputs=array_ops.ones([1, 2, 3]),
|
||||
discriminator_gen_outputs=array_ops.ones([1, 2, 3]),
|
||||
discriminator_variables=None,
|
||||
discriminator_scope=dis_scope,
|
||||
discriminator_fn=discriminator_model)
|
||||
|
||||
|
||||
def get_callable_gan_model():
|
||||
ganmodel = get_gan_model()
|
||||
return ganmodel._replace(
|
||||
generator_fn=Generator(),
|
||||
discriminator_fn=Discriminator())
|
||||
|
||||
|
||||
def create_gan_model():
|
||||
return train.gan_model(
|
||||
generator_model,
|
||||
discriminator_model,
|
||||
real_data=array_ops.zeros([1, 2]),
|
||||
generator_inputs=random_ops.random_normal([1, 2]))
|
||||
|
||||
|
||||
def create_callable_gan_model():
|
||||
return train.gan_model(
|
||||
Generator(),
|
||||
Discriminator(),
|
||||
real_data=array_ops.zeros([1, 2]),
|
||||
generator_inputs=random_ops.random_normal([1, 2]))
|
||||
|
||||
|
||||
def get_infogan_model():
|
||||
return namedtuples.InfoGANModel(
|
||||
*get_gan_model(),
|
||||
structured_generator_inputs=[constant_op.constant(0)],
|
||||
predicted_distributions=[categorical.Categorical([1.0])])
|
||||
|
||||
|
||||
def get_callable_infogan_model():
|
||||
return namedtuples.InfoGANModel(
|
||||
*get_callable_gan_model(),
|
||||
structured_generator_inputs=[constant_op.constant(0)],
|
||||
predicted_distributions=[categorical.Categorical([1.0])])
|
||||
|
||||
|
||||
def create_infogan_model():
|
||||
return train.infogan_model(
|
||||
infogan_generator_model,
|
||||
infogan_discriminator_model,
|
||||
real_data=array_ops.zeros([1, 2]),
|
||||
unstructured_generator_inputs=[],
|
||||
structured_generator_inputs=[random_ops.random_normal([1, 2])])
|
||||
|
||||
|
||||
def create_callable_infogan_model():
|
||||
return train.infogan_model(
|
||||
InfoGANGenerator(),
|
||||
InfoGANDiscriminator(),
|
||||
real_data=array_ops.zeros([1, 2]),
|
||||
unstructured_generator_inputs=[],
|
||||
structured_generator_inputs=[random_ops.random_normal([1, 2])])
|
||||
|
||||
|
||||
def get_acgan_model():
|
||||
return namedtuples.ACGANModel(
|
||||
*get_gan_model(),
|
||||
one_hot_labels=array_ops.one_hot([0, 1, 2], 10),
|
||||
discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10),
|
||||
discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10))
|
||||
|
||||
|
||||
def get_callable_acgan_model():
|
||||
return namedtuples.ACGANModel(
|
||||
*get_callable_gan_model(),
|
||||
one_hot_labels=array_ops.one_hot([0, 1, 2], 10),
|
||||
discriminator_real_classification_logits=array_ops.one_hot([0, 1, 3], 10),
|
||||
discriminator_gen_classification_logits=array_ops.one_hot([0, 1, 4], 10))
|
||||
|
||||
|
||||
def create_acgan_model():
|
||||
return train.acgan_model(
|
||||
generator_model,
|
||||
acgan_discriminator_model,
|
||||
real_data=array_ops.zeros([1, 2]),
|
||||
generator_inputs=random_ops.random_normal([1, 2]),
|
||||
one_hot_labels=array_ops.one_hot([0, 1, 2], 10))
|
||||
|
||||
|
||||
def create_callable_acgan_model():
|
||||
return train.acgan_model(
|
||||
Generator(),
|
||||
ACGANDiscriminator(),
|
||||
real_data=array_ops.zeros([1, 2]),
|
||||
generator_inputs=random_ops.random_normal([1, 2]),
|
||||
one_hot_labels=array_ops.one_hot([0, 1, 2], 10))
|
||||
|
||||
|
||||
def get_sync_optimizer():
|
||||
return sync_replicas_optimizer.SyncReplicasOptimizer(
|
||||
gradient_descent.GradientDescentOptimizer(learning_rate=1.0),
|
||||
replicas_to_aggregate=1)
|
||||
|
||||
|
||||
class GANModelTest(test.TestCase):
|
||||
"""Tests for `gan_model`."""
|
||||
|
||||
def _test_output_type_helper(self, create_fn, tuple_type):
|
||||
self.assertTrue(isinstance(create_fn(), tuple_type))
|
||||
|
||||
def test_output_type_gan(self):
|
||||
self._test_output_type_helper(get_gan_model, namedtuples.GANModel)
|
||||
|
||||
def test_output_type_callable_gan(self):
|
||||
self._test_output_type_helper(get_callable_gan_model, namedtuples.GANModel)
|
||||
|
||||
def test_output_type_infogan(self):
|
||||
self._test_output_type_helper(get_infogan_model, namedtuples.InfoGANModel)
|
||||
|
||||
def test_output_type_callable_infogan(self):
|
||||
self._test_output_type_helper(
|
||||
get_callable_infogan_model, namedtuples.InfoGANModel)
|
||||
|
||||
def test_output_type_acgan(self):
|
||||
self._test_output_type_helper(get_acgan_model, namedtuples.ACGANModel)
|
||||
|
||||
def test_output_type_callable_acgan(self):
|
||||
self._test_output_type_helper(
|
||||
get_callable_acgan_model, namedtuples.ACGANModel)
|
||||
|
||||
def test_no_shape_check(self):
|
||||
def dummy_generator_model(_):
|
||||
return (None, None)
|
||||
def dummy_discriminator_model(data, conditioning): # pylint: disable=unused-argument
|
||||
return 1
|
||||
with self.assertRaisesRegexp(AttributeError, 'object has no attribute'):
|
||||
train.gan_model(
|
||||
dummy_generator_model,
|
||||
dummy_discriminator_model,
|
||||
real_data=array_ops.zeros([1, 2]),
|
||||
generator_inputs=array_ops.zeros([1]),
|
||||
check_shapes=True)
|
||||
train.gan_model(
|
||||
dummy_generator_model,
|
||||
dummy_discriminator_model,
|
||||
real_data=array_ops.zeros([1, 2]),
|
||||
generator_inputs=array_ops.zeros([1]),
|
||||
check_shapes=False)
|
||||
|
||||
|
||||
class GANLossTest(test.TestCase):
|
||||
"""Tests for `gan_loss`."""
|
||||
|
||||
# Test output type.
|
||||
def _test_output_type_helper(self, get_gan_model_fn):
|
||||
loss = train.gan_loss(get_gan_model_fn(), add_summaries=True)
|
||||
self.assertTrue(isinstance(loss, namedtuples.GANLoss))
|
||||
self.assertGreater(len(ops.get_collection(ops.GraphKeys.SUMMARIES)), 0)
|
||||
|
||||
def test_output_type_gan(self):
|
||||
self._test_output_type_helper(get_gan_model)
|
||||
|
||||
def test_output_type_callable_gan(self):
|
||||
self._test_output_type_helper(get_callable_gan_model)
|
||||
|
||||
def test_output_type_infogan(self):
|
||||
self._test_output_type_helper(get_infogan_model)
|
||||
|
||||
def test_output_type_callable_infogan(self):
|
||||
self._test_output_type_helper(get_callable_infogan_model)
|
||||
|
||||
def test_output_type_acgan(self):
|
||||
self._test_output_type_helper(get_acgan_model)
|
||||
|
||||
def test_output_type_callable_acgan(self):
|
||||
self._test_output_type_helper(get_callable_acgan_model)
|
||||
|
||||
# Test gradient penalty option.
|
||||
def _test_grad_penalty_helper(self, create_gan_model_fn):
|
||||
model = create_gan_model_fn()
|
||||
loss = train.gan_loss(model)
|
||||
loss_gp = train.gan_loss(model, gradient_penalty_weight=1.0)
|
||||
self.assertTrue(isinstance(loss_gp, namedtuples.GANLoss))
|
||||
|
||||
# Check values.
|
||||
with self.test_session(use_gpu=True) as sess:
|
||||
variables.global_variables_initializer().run()
|
||||
loss_gen_np, loss_gen_gp_np = sess.run(
|
||||
[loss.generator_loss, loss_gp.generator_loss])
|
||||
loss_dis_np, loss_dis_gp_np = sess.run(
|
||||
[loss.discriminator_loss, loss_gp.discriminator_loss])
|
||||
|
||||
self.assertEqual(loss_gen_np, loss_gen_gp_np)
|
||||
self.assertTrue(loss_dis_np < loss_dis_gp_np)
|
||||
|
||||
def test_grad_penalty_gan(self):
|
||||
self._test_grad_penalty_helper(create_gan_model)
|
||||
|
||||
def test_grad_penalty_callable_gan(self):
|
||||
self._test_grad_penalty_helper(create_callable_gan_model)
|
||||
|
||||
def test_grad_penalty_infogan(self):
|
||||
self._test_grad_penalty_helper(create_infogan_model)
|
||||
|
||||
def test_grad_penalty_callable_infogan(self):
|
||||
self._test_grad_penalty_helper(create_callable_infogan_model)
|
||||
|
||||
def test_grad_penalty_acgan(self):
|
||||
self._test_grad_penalty_helper(create_acgan_model)
|
||||
|
||||
def test_grad_penalty_callable_acgan(self):
|
||||
self._test_grad_penalty_helper(create_callable_acgan_model)
|
||||
|
||||
# Test mutual information penalty option.
|
||||
def _test_mutual_info_penalty_helper(self, create_gan_model_fn):
|
||||
train.gan_loss(create_gan_model_fn(),
|
||||
mutual_information_penalty_weight=constant_op.constant(1.0))
|
||||
|
||||
def test_mutual_info_penalty_infogan(self):
|
||||
self._test_mutual_info_penalty_helper(get_infogan_model)
|
||||
|
||||
def test_mutual_info_penalty_callable_infogan(self):
|
||||
self._test_mutual_info_penalty_helper(get_callable_infogan_model)
|
||||
|
||||
  # Test regularization loss.
  def _test_regularization_helper(self, get_gan_model_fn):
    # Evaluate losses without regularization.
    no_reg_loss = train.gan_loss(get_gan_model_fn())
    with self.test_session(use_gpu=True):
      no_reg_loss_gen_np = no_reg_loss.generator_loss.eval()
      no_reg_loss_dis_np = no_reg_loss.discriminator_loss.eval()

    with ops.name_scope(get_gan_model_fn().generator_scope.name):
      ops.add_to_collection(
          ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(3.0))
    with ops.name_scope(get_gan_model_fn().discriminator_scope.name):
      ops.add_to_collection(
          ops.GraphKeys.REGULARIZATION_LOSSES, constant_op.constant(2.0))

    # Check that losses now include the correct regularization values.
    reg_loss = train.gan_loss(get_gan_model_fn())
    with self.test_session(use_gpu=True):
      reg_loss_gen_np = reg_loss.generator_loss.eval()
      reg_loss_dis_np = reg_loss.discriminator_loss.eval()

    self.assertEqual(3.0, reg_loss_gen_np - no_reg_loss_gen_np)
    self.assertEqual(2.0, reg_loss_dis_np - no_reg_loss_dis_np)

  def test_regularization_gan(self):
    self._test_regularization_helper(get_gan_model)

  def test_regularization_callable_gan(self):
    self._test_regularization_helper(get_callable_gan_model)

  def test_regularization_infogan(self):
    self._test_regularization_helper(get_infogan_model)

  def test_regularization_callable_infogan(self):
    self._test_regularization_helper(get_callable_infogan_model)

  def test_regularization_acgan(self):
    self._test_regularization_helper(get_acgan_model)

  def test_regularization_callable_acgan(self):
    self._test_regularization_helper(get_callable_acgan_model)

  # Test that ACGAN models work.
  def _test_acgan_helper(self, create_gan_model_fn):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    loss_ac_gen = train.gan_loss(model, aux_cond_generator_weight=1.0)
    loss_ac_dis = train.gan_loss(model, aux_cond_discriminator_weight=1.0)
    self.assertTrue(isinstance(loss, namedtuples.GANLoss))
    self.assertTrue(isinstance(loss_ac_gen, namedtuples.GANLoss))
    self.assertTrue(isinstance(loss_ac_dis, namedtuples.GANLoss))

    # Check values.
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      loss_gen_np, loss_ac_gen_gen_np, loss_ac_dis_gen_np = sess.run(
          [loss.generator_loss,
           loss_ac_gen.generator_loss,
           loss_ac_dis.generator_loss])
      loss_dis_np, loss_ac_gen_dis_np, loss_ac_dis_dis_np = sess.run(
          [loss.discriminator_loss,
           loss_ac_gen.discriminator_loss,
           loss_ac_dis.discriminator_loss])

    self.assertTrue(loss_gen_np < loss_dis_np)
    self.assertTrue(np.isscalar(loss_ac_gen_gen_np))
    self.assertTrue(np.isscalar(loss_ac_dis_gen_np))
    self.assertTrue(np.isscalar(loss_ac_gen_dis_np))
    self.assertTrue(np.isscalar(loss_ac_dis_dis_np))

  def test_acgan(self):
    self._test_acgan_helper(create_acgan_model)

  def test_callable_acgan(self):
    self._test_acgan_helper(create_callable_acgan_model)

  def test_doesnt_crash_when_in_nested_scope(self):
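    # The gradient penalty rebuilds discriminator ops with variable reuse,
    # so loss construction is sensitive to the surrounding variable scope.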
    with variable_scope.variable_scope('outer_scope'):
      gan_model = train.gan_model(
          generator_model,
          discriminator_model,
          real_data=array_ops.zeros([1, 2]),
          generator_inputs=random_ops.random_normal([1, 2]))

      # This should work inside a scope.
      train.gan_loss(gan_model, gradient_penalty_weight=1.0)

    # This should also work outside a scope.
    train.gan_loss(gan_model, gradient_penalty_weight=1.0)


class GANTrainOpsTest(test.TestCase):
  """Tests for `gan_train_ops`."""

  def _test_output_type_helper(self, create_gan_model_fn):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
    train_ops = train.gan_train_ops(
        model,
        loss,
        g_opt,
        d_opt,
        summarize_gradients=True,
        colocate_gradients_with_ops=True)

    self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))

  def test_output_type_gan(self):
    self._test_output_type_helper(create_gan_model)

  def test_output_type_callable_gan(self):
    self._test_output_type_helper(create_callable_gan_model)

  def test_output_type_infogan(self):
    self._test_output_type_helper(create_infogan_model)

  def test_output_type_callable_infogan(self):
    self._test_output_type_helper(create_callable_infogan_model)

  def test_output_type_acgan(self):
    self._test_output_type_helper(create_acgan_model)

  def test_output_type_callable_acgan(self):
    self._test_output_type_helper(create_callable_acgan_model)

  # TODO(joelshor): Add a test to check that custom update op is run.
  def _test_unused_update_ops(self, create_gan_model_fn, provide_update_ops):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    # Add generator and discriminator update ops.
    with variable_scope.variable_scope(model.generator_scope):
      gen_update_count = variable_scope.get_variable('gen_count', initializer=0)
      gen_update_op = gen_update_count.assign_add(1)
      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, gen_update_op)
    with variable_scope.variable_scope(model.discriminator_scope):
      dis_update_count = variable_scope.get_variable('dis_count', initializer=0)
      dis_update_op = dis_update_count.assign_add(1)
      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, dis_update_op)

    # Add an update op outside the generator and discriminator scopes.
    if provide_update_ops:
      kwargs = {'update_ops':
                [constant_op.constant(1.0), gen_update_op, dis_update_op]}
    else:
      ops.add_to_collection(ops.GraphKeys.UPDATE_OPS, constant_op.constant(1.0))
      kwargs = {}

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)

    with self.assertRaisesRegexp(ValueError, 'There are unused update ops:'):
      train.gan_train_ops(model, loss, g_opt, d_opt,
                          check_for_unused_update_ops=True, **kwargs)
    train_ops = train.gan_train_ops(
        model, loss, g_opt, d_opt, check_for_unused_update_ops=False, **kwargs)

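    # Each train op should run only the update ops from its own scope: the
    # generator op increments only the generator counter, and the
    # discriminator op only the discriminator counter.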
    with self.test_session(use_gpu=True) as sess:
      sess.run(variables.global_variables_initializer())
      self.assertEqual(0, gen_update_count.eval())
      self.assertEqual(0, dis_update_count.eval())

      train_ops.generator_train_op.eval()
      self.assertEqual(1, gen_update_count.eval())
      self.assertEqual(0, dis_update_count.eval())

      train_ops.discriminator_train_op.eval()
      self.assertEqual(1, gen_update_count.eval())
      self.assertEqual(1, dis_update_count.eval())

  def test_unused_update_ops_gan(self):
    self._test_unused_update_ops(create_gan_model, False)

  def test_unused_update_ops_gan_provideupdates(self):
    self._test_unused_update_ops(create_gan_model, True)

  def test_unused_update_ops_callable_gan(self):
    self._test_unused_update_ops(create_callable_gan_model, False)

  def test_unused_update_ops_callable_gan_provideupdates(self):
    self._test_unused_update_ops(create_callable_gan_model, True)

  def test_unused_update_ops_infogan(self):
    self._test_unused_update_ops(create_infogan_model, False)

  def test_unused_update_ops_infogan_provideupdates(self):
    self._test_unused_update_ops(create_infogan_model, True)

  def test_unused_update_ops_callable_infogan(self):
    self._test_unused_update_ops(create_callable_infogan_model, False)

  def test_unused_update_ops_callable_infogan_provideupdates(self):
    self._test_unused_update_ops(create_callable_infogan_model, True)

  def test_unused_update_ops_acgan(self):
    self._test_unused_update_ops(create_acgan_model, False)

  def test_unused_update_ops_acgan_provideupdates(self):
    self._test_unused_update_ops(create_acgan_model, True)

  def test_unused_update_ops_callable_acgan(self):
    self._test_unused_update_ops(create_callable_acgan_model, False)

  def test_unused_update_ops_callable_acgan_provideupdates(self):
    self._test_unused_update_ops(create_callable_acgan_model, True)

  def _test_sync_replicas_helper(self, create_gan_model_fn):
    model = create_gan_model_fn()
    loss = train.gan_loss(model)
    num_trainable_vars = len(variables_lib.get_trainable_variables())

    g_opt = get_sync_optimizer()
    d_opt = get_sync_optimizer()
    train_ops = train.gan_train_ops(
        model,
        loss,
        generator_optimizer=g_opt,
        discriminator_optimizer=d_opt)
    self.assertTrue(isinstance(train_ops, namedtuples.GANTrainOps))
    # No new trainable variables should have been added.
    self.assertEqual(num_trainable_vars,
                     len(variables_lib.get_trainable_variables()))

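    # SyncReplicasOptimizer requires its token init ops and chief queue
    # runners to be started before the train ops can run.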
    g_sync_init_op = g_opt.get_init_tokens_op(num_tokens=1)
    d_sync_init_op = d_opt.get_init_tokens_op(num_tokens=1)

    # Check that update op is run properly.
    global_step = training_util.get_or_create_global_step()
    with self.test_session(use_gpu=True) as sess:
      variables.global_variables_initializer().run()
      variables.local_variables_initializer().run()

      g_opt.chief_init_op.run()
      d_opt.chief_init_op.run()

      gstep_before = global_step.eval()

      # Start required queue runner for SyncReplicasOptimizer.
      coord = coordinator.Coordinator()
      g_threads = g_opt.get_chief_queue_runner().create_threads(sess, coord)
      d_threads = d_opt.get_chief_queue_runner().create_threads(sess, coord)

      g_sync_init_op.run()
      d_sync_init_op.run()

      train_ops.generator_train_op.eval()
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, global_step.eval())

      train_ops.discriminator_train_op.eval()
      # Check that global step wasn't incremented.
      self.assertEqual(gstep_before, global_step.eval())

      coord.request_stop()
      coord.join(g_threads + d_threads)

  def test_sync_replicas_gan(self):
    self._test_sync_replicas_helper(create_gan_model)

  def test_sync_replicas_callable_gan(self):
    self._test_sync_replicas_helper(create_callable_gan_model)

  def test_sync_replicas_infogan(self):
    self._test_sync_replicas_helper(create_infogan_model)

  def test_sync_replicas_callable_infogan(self):
    self._test_sync_replicas_helper(create_callable_infogan_model)

  def test_sync_replicas_acgan(self):
    self._test_sync_replicas_helper(create_acgan_model)

  def test_sync_replicas_callable_acgan(self):
    self._test_sync_replicas_helper(create_callable_acgan_model)


class GANTrainTest(test.TestCase):
  """Tests for `gan_train`."""

  def _gan_train_ops(self, generator_add, discriminator_add):
    step = training_util.create_global_step()
    # Increment the global count every time a train op is run so we can count
    # the number of times they're run.
    # NOTE: `use_locking=True` is required to avoid race conditions with
    # joint training.
    train_ops = namedtuples.GANTrainOps(
        generator_train_op=step.assign_add(generator_add, use_locking=True),
        discriminator_train_op=step.assign_add(discriminator_add,
                                               use_locking=True),
        global_step_inc_op=step.assign_add(1))
    return train_ops

  def _test_run_helper(self, create_gan_model_fn):
    random_seed.set_random_seed(1234)
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
    train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)

    final_step = train.gan_train(
        train_ops,
        logdir='',
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
    self.assertTrue(np.isscalar(final_step))
    self.assertEqual(2, final_step)

  def test_run_gan(self):
    self._test_run_helper(create_gan_model)

  def test_run_callable_gan(self):
    self._test_run_helper(create_callable_gan_model)

  def test_run_infogan(self):
    self._test_run_helper(create_infogan_model)

  def test_run_callable_infogan(self):
    self._test_run_helper(create_callable_infogan_model)

  def test_run_acgan(self):
    self._test_run_helper(create_acgan_model)

  def test_run_callable_acgan(self):
    self._test_run_helper(create_callable_acgan_model)

  # Test multiple train steps.
  def _test_multiple_steps_helper(self, get_hooks_fn_fn):
    train_ops = self._gan_train_ops(generator_add=10, discriminator_add=100)
    train_steps = namedtuples.GANTrainSteps(
        generator_train_steps=3,
        discriminator_train_steps=4)
    final_step = train.gan_train(
        train_ops,
        get_hooks_fn=get_hooks_fn_fn(train_steps),
        logdir='',
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=1)])

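    # Expected final step: 1 global-step increment + 3 generator steps of 10
    # + 4 discriminator steps of 100.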
    self.assertTrue(np.isscalar(final_step))
    self.assertEqual(1 + 3 * 10 + 4 * 100, final_step)

  def test_multiple_steps_seq_train_steps(self):
    self._test_multiple_steps_helper(train.get_sequential_train_hooks)

  def test_multiple_steps_efficient_seq_train_steps(self):
    self._test_multiple_steps_helper(train.get_joint_train_hooks)

  def test_supervisor_run_gan_model_train_ops_multiple_steps(self):
    step = training_util.create_global_step()
    train_ops = namedtuples.GANTrainOps(
        generator_train_op=constant_op.constant(3.0),
        discriminator_train_op=constant_op.constant(2.0),
        global_step_inc_op=step.assign_add(1))
    train_steps = namedtuples.GANTrainSteps(
        generator_train_steps=3,
        discriminator_train_steps=4)

    final_loss = slim_learning.train(
        train_op=train_ops,
        logdir='',
        global_step=step,
        number_of_steps=1,
        train_step_fn=train.get_sequential_train_steps(train_steps))
    self.assertTrue(np.isscalar(final_loss))
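    # The sequential train steps sum the train op outputs:
    # 3 generator steps * 3.0 + 4 discriminator steps * 2.0 = 17.0.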
    self.assertEqual(17.0, final_loss)


class PatchGANTest(test.TestCase):
  """Tests that functions work on PatchGAN style output."""

  def _test_patchgan_helper(self, create_gan_model_fn):
    """Ensure that patch-based discriminators work end-to-end."""
    random_seed.set_random_seed(1234)
    model = create_gan_model_fn()
    loss = train.gan_loss(model)

    g_opt = gradient_descent.GradientDescentOptimizer(1.0)
    d_opt = gradient_descent.GradientDescentOptimizer(1.0)
    train_ops = train.gan_train_ops(model, loss, g_opt, d_opt)

    final_step = train.gan_train(
        train_ops,
        logdir='',
        hooks=[basic_session_run_hooks.StopAtStepHook(num_steps=2)])
    self.assertTrue(np.isscalar(final_step))
    self.assertEqual(2, final_step)

  def test_patchgan_gan(self):
    self._test_patchgan_helper(create_gan_model)

  def test_patchgan_callable_gan(self):
    self._test_patchgan_helper(create_callable_gan_model)

  def test_patchgan_infogan(self):
    self._test_patchgan_helper(create_infogan_model)

  def test_patchgan_callable_infogan(self):
    self._test_patchgan_helper(create_callable_infogan_model)

  def test_patchgan_acgan(self):
    self._test_patchgan_helper(create_acgan_model)

  def test_patchgan_callable_acgan(self):
    self._test_patchgan_helper(create_callable_acgan_model)


if __name__ == '__main__':
  test.main()