Migrate GANEstimator to opensource.

PiperOrigin-RevId: 170597778
This commit is contained in:
A. Unique TensorFlower 2017-09-30 12:43:02 -07:00 committed by TensorFlower Gardener
parent 90dd85eed6
commit f5f24f9857
10 changed files with 1082 additions and 0 deletions

View File

@ -370,6 +370,8 @@ add_python_module("tensorflow/contrib/gan/python/eval")
add_python_module("tensorflow/contrib/gan/python/eval/python")
add_python_module("tensorflow/contrib/gan/python/features")
add_python_module("tensorflow/contrib/gan/python/features/python")
add_python_module("tensorflow/contrib/gan/python/estimator")
add_python_module("tensorflow/contrib/gan/python/estimator/python")
add_python_module("tensorflow/contrib/gan/python/losses")
add_python_module("tensorflow/contrib/gan/python/losses/python")
add_python_module("tensorflow/contrib/graph_editor")

View File

@ -14,6 +14,7 @@ py_library(
],
srcs_version = "PY2AND3",
deps = [
":estimator",
":eval",
":features",
":losses",
@ -86,6 +87,17 @@ py_library(
],
)
py_library(
name = "estimator",
srcs = ["python/estimator/__init__.py"],
srcs_version = "PY2AND3",
deps = [
":gan_estimator",
":head",
"//tensorflow/python:util",
],
)
py_library(
name = "losses",
srcs = ["python/losses/__init__.py"],
@ -369,6 +381,89 @@ py_test(
],
)
py_library(
name = "head",
srcs = [
"python/estimator/python/head.py",
"python/estimator/python/head_impl.py",
],
srcs_version = "PY2AND3",
deps = [
":namedtuples",
":train",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:model_fn",
],
)
py_test(
name = "head_test",
srcs = ["python/estimator/python/head_test.py"],
shard_count = 1,
srcs_version = "PY2AND3",
deps = [
":head",
":namedtuples",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:math_ops",
"//tensorflow/python:training",
"//tensorflow/python:variable_scope",
"//tensorflow/python/estimator:model_fn",
],
)
py_library(
name = "gan_estimator",
srcs = [
"python/estimator/python/gan_estimator.py",
"python/estimator/python/gan_estimator_impl.py",
],
srcs_version = "PY2AND3",
deps = [
":head",
":namedtuples",
":summaries",
":train",
"//tensorflow/contrib/framework:framework_py",
"//tensorflow/python:framework_ops",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
"//tensorflow/python/estimator",
"//tensorflow/python/estimator:model_fn",
],
)
py_test(
name = "gan_estimator_test",
srcs = ["python/estimator/python/gan_estimator_test.py"],
shard_count = 1,
srcs_version = "PY2AND3",
deps = [
":gan_estimator",
":namedtuples",
":tuple_losses",
"//tensorflow/contrib/layers:layers_py",
"//tensorflow/contrib/learn",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:control_flow_ops",
"//tensorflow/python:dtypes",
"//tensorflow/python:framework_ops",
"//tensorflow/python:parsing_ops",
"//tensorflow/python:summary",
"//tensorflow/python:training",
"//tensorflow/python/estimator:head",
"//tensorflow/python/estimator:model_fn",
"//tensorflow/python/estimator:numpy_io",
"//third_party/py/numpy",
"@six_archive//:six",
],
)
filegroup(
name = "all_files",
srcs = glob(

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
# Collapse TFGAN into a tiered namespace.
from tensorflow.contrib.gan.python import estimator
from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin
from tensorflow.contrib.gan.python import features
from tensorflow.contrib.gan.python import losses
@ -33,6 +34,7 @@ from tensorflow.contrib.gan.python.train import *
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'estimator',
'eval',
'features',
'losses',

View File

@ -0,0 +1,36 @@
# Copyright 2016 Google Inc. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TFGAN grouped API. Please see README.md for details and usage."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# Collapse `estimator` into a single namespace.
# pylint: disable=unused-import,wildcard-import
from tensorflow.contrib.gan.python.estimator.python import gan_estimator
from tensorflow.contrib.gan.python.estimator.python import head
from tensorflow.contrib.gan.python.estimator.python.gan_estimator import *
from tensorflow.contrib.gan.python.estimator.python.head import *
# pylint: enable=unused-import,wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = [
'gan_estimator',
'head',
] + gan_estimator.__all__ + head.__all__
remove_undocumented(__name__, _allowed_symbols)

View File

@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`tf.Learn` components for `GANEstimator`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.estimator.python.gan_estimator_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
__all__ = gan_estimator_impl.__all__
remove_undocumented(__name__, __all__)

View File

@ -0,0 +1,273 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A TFGAN-backed GAN Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum
from tensorflow.contrib.framework.python.ops import variables as variable_lib
from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python import train as tfgan_train
from tensorflow.contrib.gan.python.estimator.python import head as head_lib
from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries
from tensorflow.python.estimator import estimator
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.framework import ops
from tensorflow.python.ops import variable_scope
__all__ = [
'GANEstimator',
'SummaryType'
]
class SummaryType(enum.IntEnum):
NONE = 0
VARIABLES = 1
IMAGES = 2
IMAGE_COMPARISON = 3
_summary_type_map = {
SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries,
SummaryType.IMAGES: tfgan_summaries.add_gan_model_image_summaries,
SummaryType.IMAGE_COMPARISON: tfgan_summaries.add_image_comparison_summaries, # pylint:disable=line-too-long
}
# TODO(joelshor): For now, this only supports 1:1 generator:discriminator
# training sequentially. Find a nice way to expose options to the user without
# exposing internals.
class GANEstimator(estimator.Estimator):
"""An estimator for Generative Adversarial Networks (GANs).
This Estimator is backed by TFGAN.
Example:
```python
import tensorflow as tf
tfgan = tf.contrib.gan
# See TFGAN's `train.py` for a description of the generator and
# discriminator API.
def generator_fn(generator_inputs):
...
return generated_data
def discriminator_fn(data, conditioning):
...
return logits
# Create GAN estimator.
gan_estimator = estimator.GANEstimator(
model_dir,
generator_fn=generator_fn,
discriminator_fn=discriminator_fn,
generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
generator_optimizer=tf.train.AdamOptimizier(0.1, 0.5),
discriminator_optimizer=tf.train.AdamOptimizier(0.1, 0.5))
# Train estimator.
gan_estimator.train(train_input_fn, steps)
# Evaluate resulting estimator.
gan_estimator.evaluate(eval_input_fn)
# Generate samples from generator.
predictions = np.array([
x for x in gan_estimator.predict(predict_input_fn)])
```
"""
def __init__(self,
model_dir=None,
generator_fn=None,
discriminator_fn=None,
generator_loss_fn=None,
discriminator_loss_fn=None,
generator_optimizer=None,
discriminator_optimizer=None,
add_summaries=None,
use_loss_summaries=True,
config=None):
"""Initializes a GANEstimator instance.
Args:
model_dir: Directory to save model parameters, graph and etc. This can
also be used to load checkpoints from the directory into a estimator
to continue training a previously saved model.
generator_fn: A python function that takes a Tensor, Tensor list, or
Tensor dictionary as inputs and returns the outputs of the GAN
generator. See `TFGAN` for more details and examples.
discriminator_fn: A python function that takes the output of
`generator_fn` or real data in the GAN setup, and `generator_inputs`.
Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details
and examples.
generator_loss_fn: The loss function on the generator. Takes a `GANModel`
tuple.
discriminator_loss_fn: The loss function on the discriminator. Takes a
`GANModel` tuple.
generator_optimizer: The optimizer for generator updates, or a function
that takes no arguments and returns an optimizer. This function will
be called when the default graph is the `GANEstimator`'s graph, so
utilities like `tf.contrib.framework.get_or_create_global_step` will
work.
discriminator_optimizer: Same as `generator_optimizer`, but for the
discriminator updates.
add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
config: `RunConfig` object to configure the runtime settings.
"""
# TODO(joelshor): Explicitly validate inputs.
def _model_fn(features, labels, mode):
gopt = (generator_optimizer() if callable(generator_optimizer) else
generator_optimizer)
dopt = (discriminator_optimizer() if callable(discriminator_optimizer)
else discriminator_optimizer)
gan_head = head_lib.gan_head(
generator_loss_fn, discriminator_loss_fn, gopt, dopt,
use_loss_summaries)
return _gan_model_fn(
features, labels, mode, generator_fn, discriminator_fn, gan_head,
add_summaries)
super(GANEstimator, self).__init__(
model_fn=_model_fn, model_dir=model_dir, config=config)
def _use_check_shapes(real_data):
"""Determines whether TFGAN should check Tensor shapes."""
return isinstance(real_data, ops.Tensor)
def _gan_model_fn(
features,
labels,
mode,
generator_fn,
discriminator_fn,
head,
add_summaries=None,
generator_scope_name='Generator'):
"""The `model_fn` for the GAN estimator.
We make the following convention:
features -> TFGAN's `generator_inputs`
labels -> TFGAN's `real_data`
Args:
features: A dictionary to feed to generator. In the unconditional case,
this might be just `noise`. In the conditional GAN case, this
might be the generator's conditioning. The `generator_fn` determines
what the required keys are.
labels: Real data. Can be any structure, as long as `discriminator_fn`
can accept it for the first argument.
mode: Defines whether this is training, evaluation or prediction.
See `ModeKeys`.
generator_fn: A python lambda that takes `generator_inputs` as inputs and
returns the outputs of the GAN generator.
discriminator_fn: A python lambda that takes `real_data`/`generated data`
and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
head: A `Head` instance suitable for GANs.
add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`.
generator_scope_name: The name of the generator scope. We need this to be
the same for GANModels produced by TFGAN's `train.gan_model` and the
manually constructed ones for predictions.
Returns:
`ModelFnOps`
Raises:
ValueError: If `labels` isn't `None` during prediction.
"""
real_data = labels
generator_inputs = features
if mode == model_fn_lib.ModeKeys.TRAIN:
gan_model = _make_train_gan_model(
generator_fn, discriminator_fn, real_data, generator_inputs,
generator_scope_name, add_summaries)
elif mode == model_fn_lib.ModeKeys.EVAL:
gan_model = _make_eval_gan_model(
generator_fn, discriminator_fn, real_data, generator_inputs,
generator_scope_name, add_summaries)
else:
if real_data is not None:
raise ValueError('`labels` must be `None` when mode is `predict`. '
'Instead, found %s' % real_data)
gan_model = _make_prediction_gan_model(
generator_inputs, generator_fn, generator_scope_name)
return head.create_estimator_spec(
features=None,
mode=mode,
logits=gan_model,
labels=None)
def _make_train_gan_model(generator_fn, discriminator_fn, real_data,
generator_inputs, generator_scope, add_summaries):
"""Make a `GANModel` for training."""
gan_model = tfgan_train.gan_model(
generator_fn,
discriminator_fn,
real_data,
generator_inputs,
generator_scope=generator_scope,
check_shapes=_use_check_shapes(real_data))
if add_summaries:
if not isinstance(add_summaries, (tuple, list)):
add_summaries = [add_summaries]
with ops.name_scope(''):
for summary_type in add_summaries:
_summary_type_map[summary_type](gan_model)
return gan_model
def _make_eval_gan_model(generator_fn, discriminator_fn, real_data,
generator_inputs, generator_scope, add_summaries):
"""Make a `GANModel` for evaluation."""
return _make_train_gan_model(generator_fn, discriminator_fn, real_data,
generator_inputs, generator_scope, add_summaries)
def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope):
"""Make a `GANModel` from just the generator."""
with variable_scope.variable_scope(generator_scope) as gen_scope:
generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs) # pylint:disable=protected-access
generated_data = generator_fn(generator_inputs)
generator_variables = variable_lib.get_trainable_variables(gen_scope)
return tfgan_tuples.GANModel(
generator_inputs,
generated_data,
generator_variables,
gen_scope,
generator_fn,
real_data=None,
discriminator_real_outputs=None,
discriminator_gen_outputs=None,
discriminator_variables=None,
discriminator_scope=None,
discriminator_fn=None)

View File

@ -0,0 +1,327 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TFGAN's estimator.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import shutil
import tempfile
import numpy as np
import six
from tensorflow.contrib import layers
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator
from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses
from tensorflow.contrib.learn.python.learn.learn_io import graph_io
from tensorflow.core.example import example_pb2
from tensorflow.core.example import feature_pb2
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.inputs import numpy_io
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.platform import test
from tensorflow.python.summary.writer import writer_cache
from tensorflow.python.training import input as input_lib
from tensorflow.python.training import learning_rate_decay
from tensorflow.python.training import monitored_session
from tensorflow.python.training import training
from tensorflow.python.training import training_util
def generator_fn(noise_dict):
noise = noise_dict['x']
return layers.fully_connected(noise, noise.shape[1].value)
def discriminator_fn(data, _):
return layers.fully_connected(data, 1)
def mock_head(testcase, expected_generator_inputs, expected_real_data,
generator_scope_name):
"""Returns a mock head that validates logits values and variable names."""
discriminator_scope_name = 'Discriminator' # comes from TFGAN defaults
generator_var_names = set([
'%s/fully_connected/weights:0' % generator_scope_name,
'%s/fully_connected/biases:0' % generator_scope_name])
discriminator_var_names = set([
'%s/fully_connected/weights:0' % discriminator_scope_name,
'%s/fully_connected/biases:0' % discriminator_scope_name])
def _create_estimator_spec(features, mode, logits, labels):
gan_model = logits # renaming for clarity
is_predict = mode == model_fn_lib.ModeKeys.PREDICT
testcase.assertIsNone(features)
testcase.assertIsNone(labels)
testcase.assertIsInstance(gan_model, namedtuples.GANModel)
trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES)
expected_var_names = (generator_var_names if is_predict else
generator_var_names | discriminator_var_names)
testcase.assertItemsEqual(expected_var_names,
[var.name for var in trainable_vars])
assertions = []
def _or_none(x):
return None if is_predict else x
testcase.assertEqual(expected_generator_inputs, gan_model.generator_inputs)
# TODO(joelshor): Add check on `generated_data`.
testcase.assertItemsEqual(
generator_var_names,
set([x.name for x in gan_model.generator_variables]))
testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name)
testcase.assertEqual(generator_fn, gan_model.generator_fn)
testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data)
# TODO(joelshor): Add check on `discriminator_real_outputs`.
# TODO(joelshor): Add check on `discriminator_gen_outputs`.
if is_predict:
testcase.assertIsNone(gan_model.discriminator_scope)
else:
testcase.assertEqual(discriminator_scope_name,
gan_model.discriminator_scope.name)
testcase.assertEqual(_or_none(discriminator_fn), gan_model.discriminator_fn)
with ops.control_dependencies(assertions):
if mode == model_fn_lib.ModeKeys.TRAIN:
return model_fn_lib.EstimatorSpec(
mode=mode, loss=array_ops.zeros([]),
train_op=control_flow_ops.no_op(), training_hooks=[])
elif mode == model_fn_lib.ModeKeys.EVAL:
return model_fn_lib.EstimatorSpec(
mode=mode, predictions=gan_model.generated_data,
loss=array_ops.zeros([]))
elif mode == model_fn_lib.ModeKeys.PREDICT:
return model_fn_lib.EstimatorSpec(
mode=mode, predictions=gan_model.generated_data)
else:
testcase.fail('Invalid mode: {}'.format(mode))
head = test.mock.NonCallableMagicMock(spec=head_lib._Head)
head.create_estimator_spec = test.mock.MagicMock(
wraps=_create_estimator_spec)
return head
class GANModelFnTest(test.TestCase):
"""Tests that _gan_model_fn passes expected logits to mock head."""
def setUp(self):
self._model_dir = tempfile.mkdtemp()
def tearDown(self):
if self._model_dir:
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
def _test_logits_helper(self, mode):
"""Tests that the expected logits are passed to mock head."""
with ops.Graph().as_default():
training_util.get_or_create_global_step()
generator_inputs = {'x': array_ops.zeros([5, 4])}
real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else
array_ops.zeros([5, 4]))
generator_scope_name = 'generator'
head = mock_head(self,
expected_generator_inputs=generator_inputs,
expected_real_data=real_data,
generator_scope_name=generator_scope_name)
estimator_spec = estimator._gan_model_fn(
features=generator_inputs,
labels=real_data,
mode=mode,
generator_fn=generator_fn,
discriminator_fn=discriminator_fn,
generator_scope_name=generator_scope_name,
head=head)
with monitored_session.MonitoredTrainingSession(
checkpoint_dir=self._model_dir) as sess:
if mode == model_fn_lib.ModeKeys.TRAIN:
sess.run(estimator_spec.train_op)
elif mode == model_fn_lib.ModeKeys.EVAL:
sess.run(estimator_spec.loss)
elif mode == model_fn_lib.ModeKeys.PREDICT:
sess.run(estimator_spec.predictions)
else:
self.fail('Invalid mode: {}'.format(mode))
def test_logits_predict(self):
self._test_logits_helper(model_fn_lib.ModeKeys.PREDICT)
def test_logits_eval(self):
self._test_logits_helper(model_fn_lib.ModeKeys.EVAL)
def test_logits_train(self):
self._test_logits_helper(model_fn_lib.ModeKeys.TRAIN)
# TODO(joelshor): Add pandas test.
class GANEstimatorIntegrationTest(test.TestCase):
def setUp(self):
self._model_dir = tempfile.mkdtemp()
def tearDown(self):
if self._model_dir:
writer_cache.FileWriterCache.clear()
shutil.rmtree(self._model_dir)
def _test_complete_flow(
self, train_input_fn, eval_input_fn, predict_input_fn, prediction_size,
lr_decay=False):
def make_opt():
gstep = training_util.get_or_create_global_step()
lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9)
return training.GradientDescentOptimizer(lr)
gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0)
est = estimator.GANEstimator(
generator_fn=generator_fn,
discriminator_fn=discriminator_fn,
generator_loss_fn=losses.wasserstein_generator_loss,
discriminator_loss_fn=losses.wasserstein_discriminator_loss,
generator_optimizer=gopt,
discriminator_optimizer=dopt,
model_dir=self._model_dir)
# TRAIN
num_steps = 10
est.train(train_input_fn, steps=num_steps)
# EVALUTE
scores = est.evaluate(eval_input_fn)
self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP])
self.assertIn('loss', six.iterkeys(scores))
# PREDICT
predictions = np.array([x for x in est.predict(predict_input_fn)])
self.assertAllEqual(prediction_size, predictions.shape)
def test_numpy_input_fn(self):
"""Tests complete flow with numpy_input_fn."""
input_dim = 4
batch_size = 5
data = np.zeros([batch_size, input_dim])
train_input_fn = numpy_io.numpy_input_fn(
x={'x': data},
y=data,
batch_size=batch_size,
num_epochs=None,
shuffle=True)
eval_input_fn = numpy_io.numpy_input_fn(
x={'x': data},
y=data,
batch_size=batch_size,
shuffle=False)
predict_input_fn = numpy_io.numpy_input_fn(
x={'x': data},
batch_size=batch_size,
shuffle=False)
self._test_complete_flow(
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
predict_input_fn=predict_input_fn,
prediction_size=[batch_size, input_dim])
def test_numpy_input_fn_lrdecay(self):
"""Tests complete flow with numpy_input_fn."""
input_dim = 4
batch_size = 5
data = np.zeros([batch_size, input_dim])
train_input_fn = numpy_io.numpy_input_fn(
x={'x': data},
y=data,
batch_size=batch_size,
num_epochs=None,
shuffle=True)
eval_input_fn = numpy_io.numpy_input_fn(
x={'x': data},
y=data,
batch_size=batch_size,
shuffle=False)
predict_input_fn = numpy_io.numpy_input_fn(
x={'x': data},
batch_size=batch_size,
shuffle=False)
self._test_complete_flow(
train_input_fn=train_input_fn,
eval_input_fn=eval_input_fn,
predict_input_fn=predict_input_fn,
prediction_size=[batch_size, input_dim],
lr_decay=True)
def test_input_fn_from_parse_example(self):
"""Tests complete flow with input_fn constructed from parse_example."""
input_dim = 4
batch_size = 6
data = np.zeros([batch_size, input_dim])
serialized_examples = []
for datum in data:
example = example_pb2.Example(features=feature_pb2.Features(
feature={
'x': feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=datum)),
'y': feature_pb2.Feature(
float_list=feature_pb2.FloatList(value=datum)),
}))
serialized_examples.append(example.SerializeToString())
feature_spec = {
'x': parsing_ops.FixedLenFeature([input_dim], dtypes.float32),
'y': parsing_ops.FixedLenFeature([input_dim], dtypes.float32),
}
def _train_input_fn():
feature_map = parsing_ops.parse_example(
serialized_examples, feature_spec)
_, features = graph_io.queue_parsed_features(feature_map)
labels = features.pop('y')
return features, labels
def _eval_input_fn():
feature_map = parsing_ops.parse_example(
input_lib.limit_epochs(serialized_examples, num_epochs=1),
feature_spec)
_, features = graph_io.queue_parsed_features(feature_map)
labels = features.pop('y')
return features, labels
def _predict_input_fn():
feature_map = parsing_ops.parse_example(
input_lib.limit_epochs(serialized_examples, num_epochs=1),
feature_spec)
_, features = graph_io.queue_parsed_features(feature_map)
features.pop('y')
return features, None
self._test_complete_flow(
train_input_fn=_train_input_fn,
eval_input_fn=_eval_input_fn,
predict_input_fn=_predict_input_fn,
prediction_size=[batch_size, input_dim])
if __name__ == '__main__':
test.main()

View File

@ -0,0 +1,28 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""`tf.Learn` components for `GANEstimator`'s loss."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python.estimator.python import head_impl
# pylint: disable=wildcard-import
from tensorflow.contrib.gan.python.estimator.python.head_impl import *
# pylint: enable=wildcard-import
from tensorflow.python.util.all_util import remove_undocumented
__all__ = head_impl.__all__
remove_undocumented(__name__, __all__)

View File

@ -0,0 +1,206 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A TFGAN-backed GAN Estimator."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python import train as tfgan_train
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.estimator.canned import head
from tensorflow.python.framework import ops
__all__ = [
'GANHead',
'gan_head',
]
def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer,
discriminator_optimizer, use_loss_summaries=True,
get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
name=None):
"""Creates a `GANHead`.
Args:
generator_loss_fn: A TFGAN loss function for the generator. Takes a
`GANModel` and returns a scalar.
discriminator_loss_fn: Same as `generator_loss_fn`, but for the
discriminator.
generator_optimizer: The optimizer for generator updates.
discriminator_optimizer: Same as `generator_optimizer`, but for the
discriminator updates.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
of hooks.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
Returns:
An instance of `GANHead`.
"""
return GANHead(generator_loss_fn=generator_loss_fn,
discriminator_loss_fn=discriminator_loss_fn,
generator_optimizer=generator_optimizer,
discriminator_optimizer=discriminator_optimizer,
use_loss_summaries=use_loss_summaries,
get_hooks_fn=get_hooks_fn,
name=name)
class GANHead(head._Head): # pylint: disable=protected-access
"""`Head` for a GAN."""
def __init__(self, generator_loss_fn, discriminator_loss_fn,
generator_optimizer, discriminator_optimizer,
use_loss_summaries=True,
get_hooks_fn=tfgan_train.get_sequential_train_hooks(),
name=None):
"""`Head` for GAN training.
Args:
generator_loss_fn: A TFGAN loss function for the generator. Takes a
`GANModel` and returns a scalar.
discriminator_loss_fn: Same as `generator_loss_fn`, but for the
discriminator.
generator_optimizer: The optimizer for generator updates.
discriminator_optimizer: Same as `generator_optimizer`, but for the
discriminator updates.
use_loss_summaries: If `True`, add loss summaries. If `False`, does not.
If `None`, uses defaults.
get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
of hooks.
name: name of the head. If provided, summary and metrics keys will be
suffixed by `"/" + name`.
"""
# TODO(joelshor): Validate inputs.
if use_loss_summaries in [True, False]:
generator_loss_fn = functools.partial(
generator_loss_fn, add_summaries=use_loss_summaries)
discriminator_loss_fn = functools.partial(
discriminator_loss_fn, add_summaries=use_loss_summaries)
self._generator_loss_fn = generator_loss_fn
self._discriminator_loss_fn = discriminator_loss_fn
self._generator_optimizer = generator_optimizer
self._discriminator_optimizer = discriminator_optimizer
self._get_hooks_fn = get_hooks_fn
@property
def name(self):
return self._name
@property
def logits_dimension(self):
return None
def create_loss(self, features, mode, logits, labels):
"""Returns a GANLoss tuple from the provided GANModel.
See `Head` for more details.
Args:
features: Input `dict` of `Tensor` objects. Unused.
mode: Estimator's `ModeKeys`.
logits: A GANModel tuple.
labels: Must be `None`.
Returns:
A GANLoss tuple.
"""
_validate_logits_and_labels(logits, labels)
del mode, labels, features # unused for this head.
gan_model = logits # rename variable for clarity
return tfgan_tuples.GANLoss(
generator_loss=self._generator_loss_fn(gan_model),
discriminator_loss=self._discriminator_loss_fn(gan_model))
def create_estimator_spec(
self, features, mode, logits, labels=None,
train_op_fn=tfgan_train.gan_train_ops):
"""Returns `EstimatorSpec` that a model_fn can return.
See `Head` for more details.
Args:
features: Must be `None`.
mode: Estimator's `ModeKeys`.
logits: A GANModel tuple.
labels: Must be `None`.
train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer,
and discriminator optimizer, and returns a `GANTrainOps` tuple. For
example, this function can come from TFGAN's `train.py` library, or can
be custom.
Returns:
`EstimatorSpec`.
Raises:
ValueError: If `features` isn't `None`.
ValueError: If `train_op_fn` isn't provided in train mode.
"""
_validate_logits_and_labels(logits, labels)
if features is not None:
raise ValueError('`features` should be `None`. Instead, found: %s' %
features)
gan_model = logits # rename variable for clarity
with ops.name_scope('GANHead'):
if mode == model_fn_lib.ModeKeys.PREDICT:
return model_fn_lib.EstimatorSpec(
mode=model_fn_lib.ModeKeys.PREDICT,
predictions=gan_model.generated_data)
elif mode == model_fn_lib.ModeKeys.EVAL:
gan_loss = self.create_loss(
features=None, mode=mode, logits=gan_model, labels=None)
scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
return model_fn_lib.EstimatorSpec(
mode=model_fn_lib.ModeKeys.EVAL,
predictions=gan_model.generated_data,
loss=scalar_loss,
# TODO(joelshor): Add metrics. If head name provided, append it to
# metric keys.
eval_metric_ops={})
elif mode == model_fn_lib.ModeKeys.TRAIN:
if train_op_fn is None:
raise ValueError('train_op_fn can not be None.')
gan_loss = self.create_loss(None, mode, gan_model, None)
scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss
train_ops = train_op_fn(gan_model, gan_loss, self._generator_optimizer,
self._discriminator_optimizer)
training_hooks = self._get_hooks_fn(train_ops)
return model_fn_lib.EstimatorSpec(
loss=scalar_loss,
mode=model_fn_lib.ModeKeys.TRAIN,
train_op=train_ops.global_step_inc_op,
training_hooks=training_hooks)
else:
raise ValueError('Mode not recognized: %s' % mode)
def _validate_logits_and_labels(logits, labels):
if labels is not None:
raise ValueError('`GANHead`\'s `create_estimator_spec` input `labels` must '
'be `None`. Instead, found: %s' % labels)
if not isinstance(logits, tfgan_tuples.GANModel):
raise ValueError('`GANHead`\'s `create_estimator_spec` input `logits` must '
'be an instnace of a `GANModel`. Instead, found: %s' %
logits)

View File

@ -0,0 +1,85 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for TFGAN's head.py."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples
from tensorflow.contrib.gan.python.estimator.python import head
from tensorflow.python.estimator import model_fn as model_fn_lib
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.platform import test
from tensorflow.python.training import training
def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument
return math_ops.reduce_sum(gan_model.discriminator_real_outputs -
gan_model.discriminator_gen_outputs)
def get_gan_model():
# TODO(joelshor): Find a better way of creating a variable scope.
with variable_scope.variable_scope('generator') as gen_scope:
gen_var = variable_scope.get_variable('dummy_var', initializer=0.0)
with variable_scope.variable_scope('discriminator') as dis_scope:
dis_var = variable_scope.get_variable('dummy_var', initializer=0.0)
return tfgan_tuples.GANModel(
generator_inputs=None,
generated_data=array_ops.ones([3, 4]),
generator_variables=[gen_var],
generator_scope=gen_scope,
generator_fn=None,
real_data=None,
discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var,
discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var,
discriminator_variables=[dis_var],
discriminator_scope=dis_scope,
discriminator_fn=None)
class GANHeadTest(test.TestCase):
def setUp(self):
super(GANHeadTest, self).setUp()
self.gan_head = head.gan_head(
generator_loss_fn=dummy_loss,
discriminator_loss_fn=dummy_loss,
generator_optimizer=training.GradientDescentOptimizer(1.0),
discriminator_optimizer=training.GradientDescentOptimizer(1.0))
self.assertTrue(isinstance(self.gan_head, head.GANHead))
def _test_modes_helper(self, mode):
self.gan_head.create_estimator_spec(
features=None,
mode=mode,
logits=get_gan_model())
def test_modes_predict(self):
self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT)
def test_modes_eval(self):
self._test_modes_helper(model_fn_lib.ModeKeys.EVAL)
def test_modes_train(self):
self._test_modes_helper(model_fn_lib.ModeKeys.TRAIN)
if __name__ == '__main__':
test.main()