Migrate GANEstimator to opensource.
PiperOrigin-RevId: 170597778
This commit is contained in:
parent
90dd85eed6
commit
f5f24f9857
tensorflow/contrib
@ -370,6 +370,8 @@ add_python_module("tensorflow/contrib/gan/python/eval")
|
||||
add_python_module("tensorflow/contrib/gan/python/eval/python")
|
||||
add_python_module("tensorflow/contrib/gan/python/features")
|
||||
add_python_module("tensorflow/contrib/gan/python/features/python")
|
||||
add_python_module("tensorflow/contrib/gan/python/estimator")
|
||||
add_python_module("tensorflow/contrib/gan/python/estimator/python")
|
||||
add_python_module("tensorflow/contrib/gan/python/losses")
|
||||
add_python_module("tensorflow/contrib/gan/python/losses/python")
|
||||
add_python_module("tensorflow/contrib/graph_editor")
|
||||
|
@ -14,6 +14,7 @@ py_library(
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":estimator",
|
||||
":eval",
|
||||
":features",
|
||||
":losses",
|
||||
@ -86,6 +87,17 @@ py_library(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "estimator",
|
||||
srcs = ["python/estimator/__init__.py"],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":gan_estimator",
|
||||
":head",
|
||||
"//tensorflow/python:util",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "losses",
|
||||
srcs = ["python/losses/__init__.py"],
|
||||
@ -369,6 +381,89 @@ py_test(
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "head",
|
||||
srcs = [
|
||||
"python/estimator/python/head.py",
|
||||
"python/estimator/python/head_impl.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":namedtuples",
|
||||
":train",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python/estimator:head",
|
||||
"//tensorflow/python/estimator:model_fn",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "head_test",
|
||||
srcs = ["python/estimator/python/head_test.py"],
|
||||
shard_count = 1,
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":head",
|
||||
":namedtuples",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:math_ops",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/estimator:model_fn",
|
||||
],
|
||||
)
|
||||
|
||||
py_library(
|
||||
name = "gan_estimator",
|
||||
srcs = [
|
||||
"python/estimator/python/gan_estimator.py",
|
||||
"python/estimator/python/gan_estimator_impl.py",
|
||||
],
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":head",
|
||||
":namedtuples",
|
||||
":summaries",
|
||||
":train",
|
||||
"//tensorflow/contrib/framework:framework_py",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:util",
|
||||
"//tensorflow/python:variable_scope",
|
||||
"//tensorflow/python/estimator",
|
||||
"//tensorflow/python/estimator:model_fn",
|
||||
],
|
||||
)
|
||||
|
||||
py_test(
|
||||
name = "gan_estimator_test",
|
||||
srcs = ["python/estimator/python/gan_estimator_test.py"],
|
||||
shard_count = 1,
|
||||
srcs_version = "PY2AND3",
|
||||
deps = [
|
||||
":gan_estimator",
|
||||
":namedtuples",
|
||||
":tuple_losses",
|
||||
"//tensorflow/contrib/layers:layers_py",
|
||||
"//tensorflow/contrib/learn",
|
||||
"//tensorflow/core:protos_all_py",
|
||||
"//tensorflow/python:array_ops",
|
||||
"//tensorflow/python:client_testlib",
|
||||
"//tensorflow/python:control_flow_ops",
|
||||
"//tensorflow/python:dtypes",
|
||||
"//tensorflow/python:framework_ops",
|
||||
"//tensorflow/python:parsing_ops",
|
||||
"//tensorflow/python:summary",
|
||||
"//tensorflow/python:training",
|
||||
"//tensorflow/python/estimator:head",
|
||||
"//tensorflow/python/estimator:model_fn",
|
||||
"//tensorflow/python/estimator:numpy_io",
|
||||
"//third_party/py/numpy",
|
||||
"@six_archive//:six",
|
||||
],
|
||||
)
|
||||
|
||||
filegroup(
|
||||
name = "all_files",
|
||||
srcs = glob(
|
||||
|
@ -19,6 +19,7 @@ from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Collapse TFGAN into a tiered namespace.
|
||||
from tensorflow.contrib.gan.python import estimator
|
||||
from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin
|
||||
from tensorflow.contrib.gan.python import features
|
||||
from tensorflow.contrib.gan.python import losses
|
||||
@ -33,6 +34,7 @@ from tensorflow.contrib.gan.python.train import *
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = [
|
||||
'estimator',
|
||||
'eval',
|
||||
'features',
|
||||
'losses',
|
||||
|
36
tensorflow/contrib/gan/python/estimator/__init__.py
Normal file
36
tensorflow/contrib/gan/python/estimator/__init__.py
Normal file
@ -0,0 +1,36 @@
|
||||
# Copyright 2016 Google Inc. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""TFGAN grouped API. Please see README.md for details and usage."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
# Collapse `estimator` into a single namespace.
|
||||
# pylint: disable=unused-import,wildcard-import
|
||||
from tensorflow.contrib.gan.python.estimator.python import gan_estimator
|
||||
from tensorflow.contrib.gan.python.estimator.python import head
|
||||
|
||||
from tensorflow.contrib.gan.python.estimator.python.gan_estimator import *
|
||||
from tensorflow.contrib.gan.python.estimator.python.head import *
|
||||
# pylint: enable=unused-import,wildcard-import
|
||||
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
_allowed_symbols = [
|
||||
'gan_estimator',
|
||||
'head',
|
||||
] + gan_estimator.__all__ + head.__all__
|
||||
remove_undocumented(__name__, _allowed_symbols)
|
@ -0,0 +1,28 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""`tf.Learn` components for `GANEstimator`."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.gan.python.estimator.python.gan_estimator_impl import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
__all__ = gan_estimator_impl.__all__
|
||||
remove_undocumented(__name__, __all__)
|
@ -0,0 +1,273 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A TFGAN-backed GAN Estimator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import enum
|
||||
|
||||
from tensorflow.contrib.framework.python.ops import variables as variable_lib
|
||||
from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
|
||||
from tensorflow.contrib.gan.python import train as tfgan_train
|
||||
from tensorflow.contrib.gan.python.estimator.python import head as head_lib
|
||||
from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries
|
||||
from tensorflow.python.estimator import estimator
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
|
||||
|
||||
__all__ = [
|
||||
'GANEstimator',
|
||||
'SummaryType'
|
||||
]
|
||||
|
||||
|
||||
class SummaryType(enum.IntEnum):
|
||||
NONE = 0
|
||||
VARIABLES = 1
|
||||
IMAGES = 2
|
||||
IMAGE_COMPARISON = 3
|
||||
|
||||
|
||||
_summary_type_map = {
|
||||
SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries,
|
||||
SummaryType.IMAGES: tfgan_summaries.add_gan_model_image_summaries,
|
||||
SummaryType.IMAGE_COMPARISON: tfgan_summaries.add_image_comparison_summaries, # pylint:disable=line-too-long
|
||||
}
|
||||
|
||||
|
||||
# TODO(joelshor): For now, this only supports 1:1 generator:discriminator
|
||||
# training sequentially. Find a nice way to expose options to the user without
|
||||
# exposing internals.
|
||||
class GANEstimator(estimator.Estimator):
|
||||
"""An estimator for Generative Adversarial Networks (GANs).
|
||||
|
||||
This Estimator is backed by TFGAN.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
tfgan = tf.contrib.gan
|
||||
|
||||
# See TFGAN's `train.py` for a description of the generator and
|
||||
# discriminator API.
|
||||
def generator_fn(generator_inputs):
|
||||
...
|
||||
return generated_data
|
||||
|
||||
def discriminator_fn(data, conditioning):
|
||||
...
|
||||
return logits
|
||||
|
||||
# Create GAN estimator.
|
||||
gan_estimator = estimator.GANEstimator(
|
||||
model_dir,
|
||||
generator_fn=generator_fn,
|
||||
discriminator_fn=discriminator_fn,
|
||||
generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
|
||||
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
|
||||
generator_optimizer=tf.train.AdamOptimizier(0.1, 0.5),
|
||||
discriminator_optimizer=tf.train.AdamOptimizier(0.1, 0.5))
|
||||
|
||||
# Train estimator.
|
||||
gan_estimator.train(train_input_fn, steps)
|
||||
|
||||
# Evaluate resulting estimator.
|
||||
gan_estimator.evaluate(eval_input_fn)
|
||||
|
||||
# Generate samples from generator.
|
||||
predictions = np.array([
|
||||
x for x in gan_estimator.predict(predict_input_fn)])
|
||||
```
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model_dir=None,
|
||||
generator_fn=None,
|
||||
discriminator_fn=None,
|
||||
generator_loss_fn=None,
|
||||
discriminator_loss_fn=None,
|
||||
generator_optimizer=None,
|
||||
discriminator_optimizer=None,
|
||||
add_summaries=None,
|
||||
use_loss_summaries=True,
|
||||
config=None):
|
||||
"""Initializes a GANEstimator instance.
|
||||
|
||||
Args:
|
||||
model_dir: Directory to save model parameters, graph and etc. This can
|
||||
also be used to load checkpoints from the directory into a estimator
|
||||
to continue training a previously saved model.
|
||||
generator_fn: A python function that takes a Tensor, Tensor list, or
|
||||
Tensor dictionary as inputs and returns the outputs of the GAN
|
||||
generator. See `TFGAN` for more details and examples.
|
||||
discriminator_fn: A python function that takes the output of
|
||||
`generator_fn` or real data in the GAN setup, and `generator_inputs`.
|
||||
Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
|
||||
and examples.
|
||||
generator_loss_fn: The loss function on the generator. Takes a `GANModel`
|
||||
tuple.
|
||||
discriminator_loss_fn: The loss function on the discriminator. Takes a
|
||||
`GANModel` tuple.
|
||||
generator_optimizer: The optimizer for generator updates, or a function
|
||||
that takes no arguments and returns an optimizer. This function will
|
||||
be called when the default graph is the `GANEstimator`'s graph, so
|
||||
utilities like `tf.contrib.framework.get_or_create_global_step` will
|
||||
work.
|
||||
discriminator_optimizer: Same as `generator_optimizer`, but for the
|
||||
discriminator updates.
|
||||
add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
|
||||
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
|
||||
If `None`, uses defaults.
|
||||
config: `RunConfig` object to configure the runtime settings.
|
||||
"""
|
||||
# TODO(joelshor): Explicitly validate inputs.
|
||||
|
||||
def _model_fn(features, labels, mode):
|
||||
gopt = (generator_optimizer() if callable(generator_optimizer) else
|
||||
generator_optimizer)
|
||||
dopt = (discriminator_optimizer() if callable(discriminator_optimizer)
|
||||
else discriminator_optimizer)
|
||||
gan_head = head_lib.gan_head(
|
||||
generator_loss_fn, discriminator_loss_fn, gopt, dopt,
|
||||
use_loss_summaries)
|
||||
return _gan_model_fn(
|
||||
features, labels, mode, generator_fn, discriminator_fn, gan_head,
|
||||
add_summaries)
|
||||
|
||||
super(GANEstimator, self).__init__(
|
||||
model_fn=_model_fn, model_dir=model_dir, config=config)
|
||||
|
||||
|
||||
def _use_check_shapes(real_data):
|
||||
"""Determines whether TFGAN should check Tensor shapes."""
|
||||
return isinstance(real_data, ops.Tensor)
|
||||
|
||||
|
||||
def _gan_model_fn(
|
||||
features,
|
||||
labels,
|
||||
mode,
|
||||
generator_fn,
|
||||
discriminator_fn,
|
||||
head,
|
||||
add_summaries=None,
|
||||
generator_scope_name='Generator'):
|
||||
"""The `model_fn` for the GAN estimator.
|
||||
|
||||
We make the following convention:
|
||||
features -> TFGAN's `generator_inputs`
|
||||
labels -> TFGAN's `real_data`
|
||||
|
||||
Args:
|
||||
features: A dictionary to feed to generator. In the unconditional case,
|
||||
this might be just `noise`. In the conditional GAN case, this
|
||||
might be the generator's conditioning. The `generator_fn` determines
|
||||
what the required keys are.
|
||||
labels: Real data. Can be any structure, as long as `discriminator_fn`
|
||||
can accept it for the first argument.
|
||||
mode: Defines whether this is training, evaluation or prediction.
|
||||
See `ModeKeys`.
|
||||
generator_fn: A python lambda that takes `generator_inputs` as inputs and
|
||||
returns the outputs of the GAN generator.
|
||||
discriminator_fn: A python lambda that takes `real_data`/`generated data`
|
||||
and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
|
||||
head: A `Head` instance suitable for GANs.
|
||||
add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
|
||||
generator_scope_name: The name of the generator scope. We need this to be
|
||||
the same for GANModels produced by TFGAN's `train.gan_model` and the
|
||||
manually constructed ones for predictions.
|
||||
|
||||
Returns:
|
||||
`ModelFnOps`
|
||||
|
||||
Raises:
|
||||
ValueError: If `labels` isn't `None` during prediction.
|
||||
"""
|
||||
real_data = labels
|
||||
generator_inputs = features
|
||||
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
gan_model = _make_train_gan_model(
|
||||
generator_fn, discriminator_fn, real_data, generator_inputs,
|
||||
generator_scope_name, add_summaries)
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
gan_model = _make_eval_gan_model(
|
||||
generator_fn, discriminator_fn, real_data, generator_inputs,
|
||||
generator_scope_name, add_summaries)
|
||||
else:
|
||||
if real_data is not None:
|
||||
raise ValueError('`labels` must be `None` when mode is `predict`. '
|
||||
'Instead, found %s' % real_data)
|
||||
gan_model = _make_prediction_gan_model(
|
||||
generator_inputs, generator_fn, generator_scope_name)
|
||||
|
||||
return head.create_estimator_spec(
|
||||
features=None,
|
||||
mode=mode,
|
||||
logits=gan_model,
|
||||
labels=None)
|
||||
|
||||
|
||||
def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
|
||||
generator_inputs, generator_scope, add_summaries):
|
||||
"""Make a `GANModel` for training."""
|
||||
gan_model = tfgan_train.gan_model(
|
||||
generator_fn,
|
||||
discriminator_fn,
|
||||
real_data,
|
||||
generator_inputs,
|
||||
generator_scope=generator_scope,
|
||||
check_shapes=_use_check_shapes(real_data))
|
||||
if add_summaries:
|
||||
if not isinstance(add_summaries, (tuple, list)):
|
||||
add_summaries = [add_summaries]
|
||||
with ops.name_scope(''):
|
||||
for summary_type in add_summaries:
|
||||
_summary_type_map[summary_type](gan_model)
|
||||
|
||||
return gan_model
|
||||
|
||||
|
||||
def _make_eval_gan_model(generator_fn, discriminator_fn, real_data,
|
||||
generator_inputs, generator_scope, add_summaries):
|
||||
"""Make a `GANModel` for evaluation."""
|
||||
return _make_train_gan_model(generator_fn, discriminator_fn, real_data,
|
||||
generator_inputs, generator_scope, add_summaries)
|
||||
|
||||
|
||||
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
|
||||
"""Make a `GANModel` from just the generator."""
|
||||
with variable_scope.variable_scope(generator_scope) as gen_scope:
|
||||
generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs) # pylint:disable=protected-access
|
||||
generated_data = generator_fn(generator_inputs)
|
||||
generator_variables = variable_lib.get_trainable_variables(gen_scope)
|
||||
|
||||
return tfgan_tuples.GANModel(
|
||||
generator_inputs,
|
||||
generated_data,
|
||||
generator_variables,
|
||||
gen_scope,
|
||||
generator_fn,
|
||||
real_data=None,
|
||||
discriminator_real_outputs=None,
|
||||
discriminator_gen_outputs=None,
|
||||
discriminator_variables=None,
|
||||
discriminator_scope=None,
|
||||
discriminator_fn=None)
|
@ -0,0 +1,327 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for TFGAN's estimator.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import shutil
|
||||
import tempfile
|
||||
|
||||
import numpy as np
|
||||
import six
|
||||
|
||||
from tensorflow.contrib import layers
|
||||
from tensorflow.contrib.gan.python import namedtuples
|
||||
from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator
|
||||
from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses
|
||||
from tensorflow.contrib.learn.python.learn.learn_io import graph_io
|
||||
from tensorflow.core.example import example_pb2
|
||||
from tensorflow.core.example import feature_pb2
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator.canned import head as head_lib
|
||||
from tensorflow.python.estimator.inputs import numpy_io
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.summary.writer import writer_cache
|
||||
from tensorflow.python.training import input as input_lib
|
||||
from tensorflow.python.training import learning_rate_decay
|
||||
from tensorflow.python.training import monitored_session
|
||||
from tensorflow.python.training import training
|
||||
from tensorflow.python.training import training_util
|
||||
|
||||
|
||||
def generator_fn(noise_dict):
|
||||
noise = noise_dict['x']
|
||||
return layers.fully_connected(noise, noise.shape[1].value)
|
||||
|
||||
|
||||
def discriminator_fn(data, _):
|
||||
return layers.fully_connected(data, 1)
|
||||
|
||||
|
||||
def mock_head(testcase, expected_generator_inputs, expected_real_data,
|
||||
generator_scope_name):
|
||||
"""Returns a mock head that validates logits values and variable names."""
|
||||
discriminator_scope_name = 'Discriminator' # comes from TFGAN defaults
|
||||
generator_var_names = set([
|
||||
'%s/fully_connected/weights:0' % generator_scope_name,
|
||||
'%s/fully_connected/biases:0' % generator_scope_name])
|
||||
discriminator_var_names = set([
|
||||
'%s/fully_connected/weights:0' % discriminator_scope_name,
|
||||
'%s/fully_connected/biases:0' % discriminator_scope_name])
|
||||
|
||||
def _create_estimator_spec(features, mode, logits, labels):
|
||||
gan_model = logits # renaming for clarity
|
||||
is_predict = mode == model_fn_lib.ModeKeys.PREDICT
|
||||
testcase.assertIsNone(features)
|
||||
testcase.assertIsNone(labels)
|
||||
testcase.assertIsInstance(gan_model, namedtuples.GANModel)
|
||||
|
||||
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
|
||||
expected_var_names = (generator_var_names if is_predict else
|
||||
generator_var_names | discriminator_var_names)
|
||||
testcase.assertItemsEqual(expected_var_names,
|
||||
[var.name for var in trainable_vars])
|
||||
|
||||
assertions = []
|
||||
def _or_none(x):
|
||||
return None if is_predict else x
|
||||
testcase.assertEqual(expected_generator_inputs, gan_model.generator_inputs)
|
||||
# TODO(joelshor): Add check on `generated_data`.
|
||||
testcase.assertItemsEqual(
|
||||
generator_var_names,
|
||||
set([x.name for x in gan_model.generator_variables]))
|
||||
testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name)
|
||||
testcase.assertEqual(generator_fn, gan_model.generator_fn)
|
||||
testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data)
|
||||
# TODO(joelshor): Add check on `discriminator_real_outputs`.
|
||||
# TODO(joelshor): Add check on `discriminator_gen_outputs`.
|
||||
if is_predict:
|
||||
testcase.assertIsNone(gan_model.discriminator_scope)
|
||||
else:
|
||||
testcase.assertEqual(discriminator_scope_name,
|
||||
gan_model.discriminator_scope.name)
|
||||
testcase.assertEqual(_or_none(discriminator_fn), gan_model.discriminator_fn)
|
||||
|
||||
with ops.control_dependencies(assertions):
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=mode, loss=array_ops.zeros([]),
|
||||
train_op=control_flow_ops.no_op(), training_hooks=[])
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=mode, predictions=gan_model.generated_data,
|
||||
loss=array_ops.zeros([]))
|
||||
elif mode == model_fn_lib.ModeKeys.PREDICT:
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=mode, predictions=gan_model.generated_data)
|
||||
else:
|
||||
testcase.fail('Invalid mode: {}'.format(mode))
|
||||
|
||||
head = test.mock.NonCallableMagicMock(spec=head_lib._Head)
|
||||
head.create_estimator_spec = test.mock.MagicMock(
|
||||
wraps=_create_estimator_spec)
|
||||
|
||||
return head
|
||||
|
||||
|
||||
class GANModelFnTest(test.TestCase):
|
||||
"""Tests that _gan_model_fn passes expected logits to mock head."""
|
||||
|
||||
def setUp(self):
|
||||
self._model_dir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
if self._model_dir:
|
||||
writer_cache.FileWriterCache.clear()
|
||||
shutil.rmtree(self._model_dir)
|
||||
|
||||
def _test_logits_helper(self, mode):
|
||||
"""Tests that the expected logits are passed to mock head."""
|
||||
with ops.Graph().as_default():
|
||||
training_util.get_or_create_global_step()
|
||||
generator_inputs = {'x': array_ops.zeros([5, 4])}
|
||||
real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else
|
||||
array_ops.zeros([5, 4]))
|
||||
generator_scope_name = 'generator'
|
||||
head = mock_head(self,
|
||||
expected_generator_inputs=generator_inputs,
|
||||
expected_real_data=real_data,
|
||||
generator_scope_name=generator_scope_name)
|
||||
estimator_spec = estimator._gan_model_fn(
|
||||
features=generator_inputs,
|
||||
labels=real_data,
|
||||
mode=mode,
|
||||
generator_fn=generator_fn,
|
||||
discriminator_fn=discriminator_fn,
|
||||
generator_scope_name=generator_scope_name,
|
||||
head=head)
|
||||
with monitored_session.MonitoredTrainingSession(
|
||||
checkpoint_dir=self._model_dir) as sess:
|
||||
if mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
sess.run(estimator_spec.train_op)
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
sess.run(estimator_spec.loss)
|
||||
elif mode == model_fn_lib.ModeKeys.PREDICT:
|
||||
sess.run(estimator_spec.predictions)
|
||||
else:
|
||||
self.fail('Invalid mode: {}'.format(mode))
|
||||
|
||||
def test_logits_predict(self):
|
||||
self._test_logits_helper(model_fn_lib.ModeKeys.PREDICT)
|
||||
|
||||
def test_logits_eval(self):
|
||||
self._test_logits_helper(model_fn_lib.ModeKeys.EVAL)
|
||||
|
||||
def test_logits_train(self):
|
||||
self._test_logits_helper(model_fn_lib.ModeKeys.TRAIN)
|
||||
|
||||
|
||||
# TODO(joelshor): Add pandas test.
|
||||
class GANEstimatorIntegrationTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
self._model_dir = tempfile.mkdtemp()
|
||||
|
||||
def tearDown(self):
|
||||
if self._model_dir:
|
||||
writer_cache.FileWriterCache.clear()
|
||||
shutil.rmtree(self._model_dir)
|
||||
|
||||
def _test_complete_flow(
|
||||
self, train_input_fn, eval_input_fn, predict_input_fn, prediction_size,
|
||||
lr_decay=False):
|
||||
def make_opt():
|
||||
gstep = training_util.get_or_create_global_step()
|
||||
lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
|
||||
return training.GradientDescentOptimizer(lr)
|
||||
|
||||
gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
|
||||
dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
|
||||
est = estimator.GANEstimator(
|
||||
generator_fn=generator_fn,
|
||||
discriminator_fn=discriminator_fn,
|
||||
generator_loss_fn=losses.wasserstein_generator_loss,
|
||||
discriminator_loss_fn=losses.wasserstein_discriminator_loss,
|
||||
generator_optimizer=gopt,
|
||||
discriminator_optimizer=dopt,
|
||||
model_dir=self._model_dir)
|
||||
|
||||
# TRAIN
|
||||
num_steps = 10
|
||||
est.train(train_input_fn, steps=num_steps)
|
||||
|
||||
# EVALUTE
|
||||
scores = est.evaluate(eval_input_fn)
|
||||
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
|
||||
self.assertIn('loss', six.iterkeys(scores))
|
||||
|
||||
# PREDICT
|
||||
predictions = np.array([x for x in est.predict(predict_input_fn)])
|
||||
|
||||
self.assertAllEqual(prediction_size, predictions.shape)
|
||||
|
||||
def test_numpy_input_fn(self):
|
||||
"""Tests complete flow with numpy_input_fn."""
|
||||
input_dim = 4
|
||||
batch_size = 5
|
||||
data = np.zeros([batch_size, input_dim])
|
||||
train_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
y=data,
|
||||
batch_size=batch_size,
|
||||
num_epochs=None,
|
||||
shuffle=True)
|
||||
eval_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
y=data,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
predict_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
self._test_complete_flow(
|
||||
train_input_fn=train_input_fn,
|
||||
eval_input_fn=eval_input_fn,
|
||||
predict_input_fn=predict_input_fn,
|
||||
prediction_size=[batch_size, input_dim])
|
||||
|
||||
def test_numpy_input_fn_lrdecay(self):
|
||||
"""Tests complete flow with numpy_input_fn."""
|
||||
input_dim = 4
|
||||
batch_size = 5
|
||||
data = np.zeros([batch_size, input_dim])
|
||||
train_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
y=data,
|
||||
batch_size=batch_size,
|
||||
num_epochs=None,
|
||||
shuffle=True)
|
||||
eval_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
y=data,
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
predict_input_fn = numpy_io.numpy_input_fn(
|
||||
x={'x': data},
|
||||
batch_size=batch_size,
|
||||
shuffle=False)
|
||||
|
||||
self._test_complete_flow(
|
||||
train_input_fn=train_input_fn,
|
||||
eval_input_fn=eval_input_fn,
|
||||
predict_input_fn=predict_input_fn,
|
||||
prediction_size=[batch_size, input_dim],
|
||||
lr_decay=True)
|
||||
|
||||
def test_input_fn_from_parse_example(self):
|
||||
"""Tests complete flow with input_fn constructed from parse_example."""
|
||||
input_dim = 4
|
||||
batch_size = 6
|
||||
data = np.zeros([batch_size, input_dim])
|
||||
|
||||
serialized_examples = []
|
||||
for datum in data:
|
||||
example = example_pb2.Example(features=feature_pb2.Features(
|
||||
feature={
|
||||
'x': feature_pb2.Feature(
|
||||
float_list=feature_pb2.FloatList(value=datum)),
|
||||
'y': feature_pb2.Feature(
|
||||
float_list=feature_pb2.FloatList(value=datum)),
|
||||
}))
|
||||
serialized_examples.append(example.SerializeToString())
|
||||
|
||||
feature_spec = {
|
||||
'x': parsing_ops.FixedLenFeature([input_dim], dtypes.float32),
|
||||
'y': parsing_ops.FixedLenFeature([input_dim], dtypes.float32),
|
||||
}
|
||||
def _train_input_fn():
|
||||
feature_map = parsing_ops.parse_example(
|
||||
serialized_examples, feature_spec)
|
||||
_, features = graph_io.queue_parsed_features(feature_map)
|
||||
labels = features.pop('y')
|
||||
return features, labels
|
||||
def _eval_input_fn():
|
||||
feature_map = parsing_ops.parse_example(
|
||||
input_lib.limit_epochs(serialized_examples, num_epochs=1),
|
||||
feature_spec)
|
||||
_, features = graph_io.queue_parsed_features(feature_map)
|
||||
labels = features.pop('y')
|
||||
return features, labels
|
||||
def _predict_input_fn():
|
||||
feature_map = parsing_ops.parse_example(
|
||||
input_lib.limit_epochs(serialized_examples, num_epochs=1),
|
||||
feature_spec)
|
||||
_, features = graph_io.queue_parsed_features(feature_map)
|
||||
features.pop('y')
|
||||
return features, None
|
||||
|
||||
self._test_complete_flow(
|
||||
train_input_fn=_train_input_fn,
|
||||
eval_input_fn=_eval_input_fn,
|
||||
predict_input_fn=_predict_input_fn,
|
||||
prediction_size=[batch_size, input_dim])
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
28
tensorflow/contrib/gan/python/estimator/python/head.py
Normal file
28
tensorflow/contrib/gan/python/estimator/python/head.py
Normal file
@ -0,0 +1,28 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""`tf.Learn` components for `GANEstimator`'s loss."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.gan.python.estimator.python import head_impl
|
||||
# pylint: disable=wildcard-import
|
||||
from tensorflow.contrib.gan.python.estimator.python.head_impl import *
|
||||
# pylint: enable=wildcard-import
|
||||
from tensorflow.python.util.all_util import remove_undocumented
|
||||
|
||||
__all__ = head_impl.__all__
|
||||
remove_undocumented(__name__, __all__)
|
206
tensorflow/contrib/gan/python/estimator/python/head_impl.py
Normal file
206
tensorflow/contrib/gan/python/estimator/python/head_impl.py
Normal file
@ -0,0 +1,206 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""A TFGAN-backed GAN Estimator."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
import functools
|
||||
|
||||
from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
|
||||
from tensorflow.contrib.gan.python import train as tfgan_train
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.estimator.canned import head
|
||||
from tensorflow.python.framework import ops
|
||||
|
||||
__all__ = [
|
||||
'GANHead',
|
||||
'gan_head',
|
||||
]
|
||||
|
||||
|
||||
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
|
||||
discriminator_optimizer, use_loss_summaries=True,
|
||||
get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
|
||||
name=None):
|
||||
"""Creates a `GANHead`.
|
||||
|
||||
Args:
|
||||
generator_loss_fn: A TFGAN loss function for the generator. Takes a
|
||||
`GANModel` and returns a scalar.
|
||||
discriminator_loss_fn: Same as `generator_loss_fn`, but for the
|
||||
discriminator.
|
||||
generator_optimizer: The optimizer for generator updates.
|
||||
discriminator_optimizer: Same as `generator_optimizer`, but for the
|
||||
discriminator updates.
|
||||
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
|
||||
If `None`, uses defaults.
|
||||
get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
|
||||
of hooks.
|
||||
name: name of the head. If provided, summary and metrics keys will be
|
||||
suffixed by `"/" + name`.
|
||||
|
||||
Returns:
|
||||
An instance of `GANHead`.
|
||||
"""
|
||||
return GANHead(generator_loss_fn=generator_loss_fn,
|
||||
discriminator_loss_fn=discriminator_loss_fn,
|
||||
generator_optimizer=generator_optimizer,
|
||||
discriminator_optimizer=discriminator_optimizer,
|
||||
use_loss_summaries=use_loss_summaries,
|
||||
get_hooks_fn=get_hooks_fn,
|
||||
name=name)
|
||||
|
||||
|
||||
class GANHead(head._Head): # pylint: disable=protected-access
|
||||
"""`Head` for a GAN."""
|
||||
|
||||
def __init__(self, generator_loss_fn, discriminator_loss_fn,
|
||||
generator_optimizer, discriminator_optimizer,
|
||||
use_loss_summaries=True,
|
||||
get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
|
||||
name=None):
|
||||
"""`Head` for GAN training.
|
||||
|
||||
Args:
|
||||
generator_loss_fn: A TFGAN loss function for the generator. Takes a
|
||||
`GANModel` and returns a scalar.
|
||||
discriminator_loss_fn: Same as `generator_loss_fn`, but for the
|
||||
discriminator.
|
||||
generator_optimizer: The optimizer for generator updates.
|
||||
discriminator_optimizer: Same as `generator_optimizer`, but for the
|
||||
discriminator updates.
|
||||
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
|
||||
If `None`, uses defaults.
|
||||
get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
|
||||
of hooks.
|
||||
name: name of the head. If provided, summary and metrics keys will be
|
||||
suffixed by `"/" + name`.
|
||||
"""
|
||||
# TODO(joelshor): Validate inputs.
|
||||
|
||||
if use_loss_summaries in [True, False]:
|
||||
generator_loss_fn = functools.partial(
|
||||
generator_loss_fn, add_summaries=use_loss_summaries)
|
||||
discriminator_loss_fn = functools.partial(
|
||||
discriminator_loss_fn, add_summaries=use_loss_summaries)
|
||||
self._generator_loss_fn = generator_loss_fn
|
||||
self._discriminator_loss_fn = discriminator_loss_fn
|
||||
self._generator_optimizer = generator_optimizer
|
||||
self._discriminator_optimizer = discriminator_optimizer
|
||||
self._get_hooks_fn = get_hooks_fn
|
||||
|
||||
@property
|
||||
def name(self):
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def logits_dimension(self):
|
||||
return None
|
||||
|
||||
def create_loss(self, features, mode, logits, labels):
|
||||
"""Returns a GANLoss tuple from the provided GANModel.
|
||||
|
||||
See `Head` for more details.
|
||||
|
||||
Args:
|
||||
features: Input `dict` of `Tensor` objects. Unused.
|
||||
mode: Estimator's `ModeKeys`.
|
||||
logits: A GANModel tuple.
|
||||
labels: Must be `None`.
|
||||
|
||||
Returns:
|
||||
A GANLoss tuple.
|
||||
|
||||
"""
|
||||
_validate_logits_and_labels(logits, labels)
|
||||
del mode, labels, features # unused for this head.
|
||||
gan_model = logits # rename variable for clarity
|
||||
return tfgan_tuples.GANLoss(
|
||||
generator_loss=self._generator_loss_fn(gan_model),
|
||||
discriminator_loss=self._discriminator_loss_fn(gan_model))
|
||||
|
||||
def create_estimator_spec(
|
||||
self, features, mode, logits, labels=None,
|
||||
train_op_fn=tfgan_train.gan_train_ops):
|
||||
"""Returns `EstimatorSpec` that a model_fn can return.
|
||||
|
||||
See `Head` for more details.
|
||||
|
||||
Args:
|
||||
features: Must be `None`.
|
||||
mode: Estimator's `ModeKeys`.
|
||||
logits: A GANModel tuple.
|
||||
labels: Must be `None`.
|
||||
train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer,
|
||||
and discriminator optimizer, and returns a `GANTrainOps` tuple. For
|
||||
example, this function can come from TFGAN's `train.py` library, or can
|
||||
be custom.
|
||||
|
||||
Returns:
|
||||
`EstimatorSpec`.
|
||||
|
||||
Raises:
|
||||
ValueError: If `features` isn't `None`.
|
||||
ValueError: If `train_op_fn` isn't provided in train mode.
|
||||
"""
|
||||
_validate_logits_and_labels(logits, labels)
|
||||
if features is not None:
|
||||
raise ValueError('`features` should be `None`. Instead, found: %s' %
|
||||
features)
|
||||
gan_model = logits # rename variable for clarity
|
||||
with ops.name_scope('GANHead'):
|
||||
if mode == model_fn_lib.ModeKeys.PREDICT:
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=model_fn_lib.ModeKeys.PREDICT,
|
||||
predictions=gan_model.generated_data)
|
||||
elif mode == model_fn_lib.ModeKeys.EVAL:
|
||||
gan_loss = self.create_loss(
|
||||
features=None, mode=mode, logits=gan_model, labels=None)
|
||||
scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
mode=model_fn_lib.ModeKeys.EVAL,
|
||||
predictions=gan_model.generated_data,
|
||||
loss=scalar_loss,
|
||||
# TODO(joelshor): Add metrics. If head name provided, append it to
|
||||
# metric keys.
|
||||
eval_metric_ops={})
|
||||
elif mode == model_fn_lib.ModeKeys.TRAIN:
|
||||
if train_op_fn is None:
|
||||
raise ValueError('train_op_fn can not be None.')
|
||||
gan_loss = self.create_loss(None, mode, gan_model, None)
|
||||
scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
|
||||
train_ops = train_op_fn(gan_model, gan_loss, self._generator_optimizer,
|
||||
self._discriminator_optimizer)
|
||||
training_hooks = self._get_hooks_fn(train_ops)
|
||||
return model_fn_lib.EstimatorSpec(
|
||||
loss=scalar_loss,
|
||||
mode=model_fn_lib.ModeKeys.TRAIN,
|
||||
train_op=train_ops.global_step_inc_op,
|
||||
training_hooks=training_hooks)
|
||||
else:
|
||||
raise ValueError('Mode not recognized: %s' % mode)
|
||||
|
||||
|
||||
def _validate_logits_and_labels(logits, labels):
|
||||
if labels is not None:
|
||||
raise ValueError('`GANHead`\'s `create_estimator_spec` input `labels` must '
|
||||
'be `None`. Instead, found: %s' % labels)
|
||||
|
||||
if not isinstance(logits, tfgan_tuples.GANModel):
|
||||
raise ValueError('`GANHead`\'s `create_estimator_spec` input `logits` must '
|
||||
'be an instnace of a `GANModel`. Instead, found: %s' %
|
||||
logits)
|
85
tensorflow/contrib/gan/python/estimator/python/head_test.py
Normal file
85
tensorflow/contrib/gan/python/estimator/python/head_test.py
Normal file
@ -0,0 +1,85 @@
|
||||
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ==============================================================================
|
||||
"""Tests for TFGAN's head.py."""
|
||||
|
||||
from __future__ import absolute_import
|
||||
from __future__ import division
|
||||
from __future__ import print_function
|
||||
|
||||
from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
|
||||
from tensorflow.contrib.gan.python.estimator.python import head
|
||||
|
||||
from tensorflow.python.estimator import model_fn as model_fn_lib
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import variable_scope
|
||||
from tensorflow.python.platform import test
|
||||
from tensorflow.python.training import training
|
||||
|
||||
|
||||
def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument
|
||||
return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
|
||||
gan_model.discriminator_gen_outputs)
|
||||
|
||||
|
||||
def get_gan_model():
|
||||
# TODO(joelshor): Find a better way of creating a variable scope.
|
||||
with variable_scope.variable_scope('generator') as gen_scope:
|
||||
gen_var = variable_scope.get_variable('dummy_var', initializer=0.0)
|
||||
with variable_scope.variable_scope('discriminator') as dis_scope:
|
||||
dis_var = variable_scope.get_variable('dummy_var', initializer=0.0)
|
||||
return tfgan_tuples.GANModel(
|
||||
generator_inputs=None,
|
||||
generated_data=array_ops.ones([3, 4]),
|
||||
generator_variables=[gen_var],
|
||||
generator_scope=gen_scope,
|
||||
generator_fn=None,
|
||||
real_data=None,
|
||||
discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var,
|
||||
discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var,
|
||||
discriminator_variables=[dis_var],
|
||||
discriminator_scope=dis_scope,
|
||||
discriminator_fn=None)
|
||||
|
||||
|
||||
class GANHeadTest(test.TestCase):
|
||||
|
||||
def setUp(self):
|
||||
super(GANHeadTest, self).setUp()
|
||||
self.gan_head = head.gan_head(
|
||||
generator_loss_fn=dummy_loss,
|
||||
discriminator_loss_fn=dummy_loss,
|
||||
generator_optimizer=training.GradientDescentOptimizer(1.0),
|
||||
discriminator_optimizer=training.GradientDescentOptimizer(1.0))
|
||||
self.assertTrue(isinstance(self.gan_head, head.GANHead))
|
||||
|
||||
def _test_modes_helper(self, mode):
|
||||
self.gan_head.create_estimator_spec(
|
||||
features=None,
|
||||
mode=mode,
|
||||
logits=get_gan_model())
|
||||
|
||||
def test_modes_predict(self):
|
||||
self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT)
|
||||
|
||||
def test_modes_eval(self):
|
||||
self._test_modes_helper(model_fn_lib.ModeKeys.EVAL)
|
||||
|
||||
def test_modes_train(self):
|
||||
self._test_modes_helper(model_fn_lib.ModeKeys.TRAIN)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test.main()
|
Loading…
Reference in New Issue
Block a user