diff --git a/tensorflow/contrib/cmake/tf_python.cmake b/tensorflow/contrib/cmake/tf_python.cmake index 3430439d4d0..a19889f3e20 100755 --- a/tensorflow/contrib/cmake/tf_python.cmake +++ b/tensorflow/contrib/cmake/tf_python.cmake @@ -370,6 +370,8 @@ add_python_module("tensorflow/contrib/gan/python/eval") add_python_module("tensorflow/contrib/gan/python/eval/python") add_python_module("tensorflow/contrib/gan/python/features") add_python_module("tensorflow/contrib/gan/python/features/python") +add_python_module("tensorflow/contrib/gan/python/estimator") +add_python_module("tensorflow/contrib/gan/python/estimator/python") add_python_module("tensorflow/contrib/gan/python/losses") add_python_module("tensorflow/contrib/gan/python/losses/python") add_python_module("tensorflow/contrib/graph_editor") diff --git a/tensorflow/contrib/gan/BUILD b/tensorflow/contrib/gan/BUILD index 54dbb11b6eb..64bff7cecf5 100644 --- a/tensorflow/contrib/gan/BUILD +++ b/tensorflow/contrib/gan/BUILD @@ -14,6 +14,7 @@ py_library( ], srcs_version = "PY2AND3", deps = [ + ":estimator", ":eval", ":features", ":losses", @@ -86,6 +87,17 @@ py_library( ], ) +py_library( + name = "estimator", + srcs = ["python/estimator/__init__.py"], + srcs_version = "PY2AND3", + deps = [ + ":gan_estimator", + ":head", + "//tensorflow/python:util", + ], +) + py_library( name = "losses", srcs = ["python/losses/__init__.py"], @@ -369,6 +381,89 @@ py_test( ], ) +py_library( + name = "head", + srcs = [ + "python/estimator/python/head.py", + "python/estimator/python/head_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":namedtuples", + ":train", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "head_test", + srcs = ["python/estimator/python/head_test.py"], + shard_count = 1, + srcs_version = "PY2AND3", + deps = [ + ":head", + ":namedtuples", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:math_ops", + "//tensorflow/python:training", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_library( + name = "gan_estimator", + srcs = [ + "python/estimator/python/gan_estimator.py", + "python/estimator/python/gan_estimator_impl.py", + ], + srcs_version = "PY2AND3", + deps = [ + ":head", + ":namedtuples", + ":summaries", + ":train", + "//tensorflow/contrib/framework:framework_py", + "//tensorflow/python:framework_ops", + "//tensorflow/python:util", + "//tensorflow/python:variable_scope", + "//tensorflow/python/estimator", + "//tensorflow/python/estimator:model_fn", + ], +) + +py_test( + name = "gan_estimator_test", + srcs = ["python/estimator/python/gan_estimator_test.py"], + shard_count = 1, + srcs_version = "PY2AND3", + deps = [ + ":gan_estimator", + ":namedtuples", + ":tuple_losses", + "//tensorflow/contrib/layers:layers_py", + "//tensorflow/contrib/learn", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:array_ops", + "//tensorflow/python:client_testlib", + "//tensorflow/python:control_flow_ops", + "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", + "//tensorflow/python:parsing_ops", + "//tensorflow/python:summary", + "//tensorflow/python:training", + "//tensorflow/python/estimator:head", + "//tensorflow/python/estimator:model_fn", + "//tensorflow/python/estimator:numpy_io", + "//third_party/py/numpy", + "@six_archive//:six", + ], +) + filegroup( name = "all_files", srcs = glob( diff --git a/tensorflow/contrib/gan/__init__.py b/tensorflow/contrib/gan/__init__.py index 67eee771d04..dff361fdc42 100644 --- a/tensorflow/contrib/gan/__init__.py +++ b/tensorflow/contrib/gan/__init__.py @@ -19,6 +19,7 @@ from __future__ import division from __future__ import print_function # Collapse TFGAN into a tiered namespace. +from tensorflow.contrib.gan.python import estimator from tensorflow.contrib.gan.python import eval # pylint:disable=redefined-builtin from tensorflow.contrib.gan.python import features from tensorflow.contrib.gan.python import losses @@ -33,6 +34,7 @@ from tensorflow.contrib.gan.python.train import * from tensorflow.python.util.all_util import remove_undocumented _allowed_symbols = [ + 'estimator', 'eval', 'features', 'losses', diff --git a/tensorflow/contrib/gan/python/estimator/__init__.py b/tensorflow/contrib/gan/python/estimator/__init__.py new file mode 100644 index 00000000000..8c4a1822803 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/__init__.py @@ -0,0 +1,36 @@ +# Copyright 2016 Google Inc. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""TFGAN grouped API. Please see README.md for details and usage.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +# Collapse `estimator` into a single namespace. +# pylint: disable=unused-import,wildcard-import +from tensorflow.contrib.gan.python.estimator.python import gan_estimator +from tensorflow.contrib.gan.python.estimator.python import head + +from tensorflow.contrib.gan.python.estimator.python.gan_estimator import * +from tensorflow.contrib.gan.python.estimator.python.head import * +# pylint: enable=unused-import,wildcard-import + +from tensorflow.python.util.all_util import remove_undocumented + +_allowed_symbols = [ + 'gan_estimator', + 'head', +] + gan_estimator.__all__ + head.__all__ +remove_undocumented(__name__, _allowed_symbols) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py new file mode 100644 index 00000000000..bc0e4854091 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`tf.Learn` components for `GANEstimator`.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.gan_estimator_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = gan_estimator_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py new file mode 100644 index 00000000000..6e1ee730aac --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_impl.py @@ -0,0 +1,273 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TFGAN-backed GAN Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import enum + +from tensorflow.contrib.framework.python.ops import variables as variable_lib +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.contrib.gan.python.estimator.python import head as head_lib +from tensorflow.contrib.gan.python.eval.python import summaries as tfgan_summaries +from tensorflow.python.estimator import estimator +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.framework import ops +from tensorflow.python.ops import variable_scope + + +__all__ = [ + 'GANEstimator', + 'SummaryType' +] + + +class SummaryType(enum.IntEnum): + NONE = 0 + VARIABLES = 1 + IMAGES = 2 + IMAGE_COMPARISON = 3 + + +_summary_type_map = { + SummaryType.VARIABLES: tfgan_summaries.add_gan_model_summaries, + SummaryType.IMAGES: tfgan_summaries.add_gan_model_image_summaries, + SummaryType.IMAGE_COMPARISON: tfgan_summaries.add_image_comparison_summaries, # pylint:disable=line-too-long +} + + +# TODO(joelshor): For now, this only supports 1:1 generator:discriminator +# training sequentially. Find a nice way to expose options to the user without +# exposing internals. +class GANEstimator(estimator.Estimator): + """An estimator for Generative Adversarial Networks (GANs). + + This Estimator is backed by TFGAN. + + Example: + + ```python + import tensorflow as tf + tfgan = tf.contrib.gan + + # See TFGAN's `train.py` for a description of the generator and + # discriminator API. + def generator_fn(generator_inputs): + ... + return generated_data + + def discriminator_fn(data, conditioning): + ... + return logits + + # Create GAN estimator. + gan_estimator = estimator.GANEstimator( + model_dir, + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=tfgan.losses.wasserstein_generator_loss, + discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss, + generator_optimizer=tf.train.AdamOptimizier(0.1, 0.5), + discriminator_optimizer=tf.train.AdamOptimizier(0.1, 0.5)) + + # Train estimator. + gan_estimator.train(train_input_fn, steps) + + # Evaluate resulting estimator. + gan_estimator.evaluate(eval_input_fn) + + # Generate samples from generator. + predictions = np.array([ + x for x in gan_estimator.predict(predict_input_fn)]) + ``` + """ + + def __init__(self, + model_dir=None, + generator_fn=None, + discriminator_fn=None, + generator_loss_fn=None, + discriminator_loss_fn=None, + generator_optimizer=None, + discriminator_optimizer=None, + add_summaries=None, + use_loss_summaries=True, + config=None): + """Initializes a GANEstimator instance. + + Args: + model_dir: Directory to save model parameters, graph and etc. This can + also be used to load checkpoints from the directory into a estimator + to continue training a previously saved model. + generator_fn: A python function that takes a Tensor, Tensor list, or + Tensor dictionary as inputs and returns the outputs of the GAN + generator. See `TFGAN` for more details and examples. + discriminator_fn: A python function that takes the output of + `generator_fn` or real data in the GAN setup, and `generator_inputs`. + Outputs a Tensor in the range [-inf, inf]. See `TFGAN` for more details + and examples. + generator_loss_fn: The loss function on the generator. Takes a `GANModel` + tuple. + discriminator_loss_fn: The loss function on the discriminator. Takes a + `GANModel` tuple. + generator_optimizer: The optimizer for generator updates, or a function + that takes no arguments and returns an optimizer. This function will + be called when the default graph is the `GANEstimator`'s graph, so + utilities like `tf.contrib.framework.get_or_create_global_step` will + work. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. + use_loss_summaries: If `True`, add loss summaries. If `False`, does not. + If `None`, uses defaults. + config: `RunConfig` object to configure the runtime settings. + """ + # TODO(joelshor): Explicitly validate inputs. + + def _model_fn(features, labels, mode): + gopt = (generator_optimizer() if callable(generator_optimizer) else + generator_optimizer) + dopt = (discriminator_optimizer() if callable(discriminator_optimizer) + else discriminator_optimizer) + gan_head = head_lib.gan_head( + generator_loss_fn, discriminator_loss_fn, gopt, dopt, + use_loss_summaries) + return _gan_model_fn( + features, labels, mode, generator_fn, discriminator_fn, gan_head, + add_summaries) + + super(GANEstimator, self).__init__( + model_fn=_model_fn, model_dir=model_dir, config=config) + + +def _use_check_shapes(real_data): + """Determines whether TFGAN should check Tensor shapes.""" + return isinstance(real_data, ops.Tensor) + + +def _gan_model_fn( + features, + labels, + mode, + generator_fn, + discriminator_fn, + head, + add_summaries=None, + generator_scope_name='Generator'): + """The `model_fn` for the GAN estimator. + + We make the following convention: + features -> TFGAN's `generator_inputs` + labels -> TFGAN's `real_data` + + Args: + features: A dictionary to feed to generator. In the unconditional case, + this might be just `noise`. In the conditional GAN case, this + might be the generator's conditioning. The `generator_fn` determines + what the required keys are. + labels: Real data. Can be any structure, as long as `discriminator_fn` + can accept it for the first argument. + mode: Defines whether this is training, evaluation or prediction. + See `ModeKeys`. + generator_fn: A python lambda that takes `generator_inputs` as inputs and + returns the outputs of the GAN generator. + discriminator_fn: A python lambda that takes `real_data`/`generated data` + and `generator_inputs`. Outputs a Tensor in the range [-inf, inf]. + head: A `Head` instance suitable for GANs. + add_summaries: `None`, a single `SummaryType`, or a list of `SummaryType`. + generator_scope_name: The name of the generator scope. We need this to be + the same for GANModels produced by TFGAN's `train.gan_model` and the + manually constructed ones for predictions. + + Returns: + `ModelFnOps` + + Raises: + ValueError: If `labels` isn't `None` during prediction. + """ + real_data = labels + generator_inputs = features + + if mode == model_fn_lib.ModeKeys.TRAIN: + gan_model = _make_train_gan_model( + generator_fn, discriminator_fn, real_data, generator_inputs, + generator_scope_name, add_summaries) + elif mode == model_fn_lib.ModeKeys.EVAL: + gan_model = _make_eval_gan_model( + generator_fn, discriminator_fn, real_data, generator_inputs, + generator_scope_name, add_summaries) + else: + if real_data is not None: + raise ValueError('`labels` must be `None` when mode is `predict`. ' + 'Instead, found %s' % real_data) + gan_model = _make_prediction_gan_model( + generator_inputs, generator_fn, generator_scope_name) + + return head.create_estimator_spec( + features=None, + mode=mode, + logits=gan_model, + labels=None) + + +def _make_train_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries): + """Make a `GANModel` for training.""" + gan_model = tfgan_train.gan_model( + generator_fn, + discriminator_fn, + real_data, + generator_inputs, + generator_scope=generator_scope, + check_shapes=_use_check_shapes(real_data)) + if add_summaries: + if not isinstance(add_summaries, (tuple, list)): + add_summaries = [add_summaries] + with ops.name_scope(''): + for summary_type in add_summaries: + _summary_type_map[summary_type](gan_model) + + return gan_model + + +def _make_eval_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries): + """Make a `GANModel` for evaluation.""" + return _make_train_gan_model(generator_fn, discriminator_fn, real_data, + generator_inputs, generator_scope, add_summaries) + + +def _make_prediction_gan_model(generator_inputs, generator_fn, generator_scope): + """Make a `GANModel` from just the generator.""" + with variable_scope.variable_scope(generator_scope) as gen_scope: + generator_inputs = tfgan_train._convert_tensor_or_l_or_d(generator_inputs) # pylint:disable=protected-access + generated_data = generator_fn(generator_inputs) + generator_variables = variable_lib.get_trainable_variables(gen_scope) + + return tfgan_tuples.GANModel( + generator_inputs, + generated_data, + generator_variables, + gen_scope, + generator_fn, + real_data=None, + discriminator_real_outputs=None, + discriminator_gen_outputs=None, + discriminator_variables=None, + discriminator_scope=None, + discriminator_fn=None) diff --git a/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py new file mode 100644 index 00000000000..1bfdce9ee94 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/gan_estimator_test.py @@ -0,0 +1,327 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TFGAN's estimator.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import shutil +import tempfile + +import numpy as np +import six + +from tensorflow.contrib import layers +from tensorflow.contrib.gan.python import namedtuples +from tensorflow.contrib.gan.python.estimator.python import gan_estimator_impl as estimator +from tensorflow.contrib.gan.python.losses.python import tuple_losses as losses +from tensorflow.contrib.learn.python.learn.learn_io import graph_io +from tensorflow.core.example import example_pb2 +from tensorflow.core.example import feature_pb2 +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.canned import head as head_lib +from tensorflow.python.estimator.inputs import numpy_io +from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import control_flow_ops +from tensorflow.python.ops import parsing_ops +from tensorflow.python.platform import test +from tensorflow.python.summary.writer import writer_cache +from tensorflow.python.training import input as input_lib +from tensorflow.python.training import learning_rate_decay +from tensorflow.python.training import monitored_session +from tensorflow.python.training import training +from tensorflow.python.training import training_util + + +def generator_fn(noise_dict): + noise = noise_dict['x'] + return layers.fully_connected(noise, noise.shape[1].value) + + +def discriminator_fn(data, _): + return layers.fully_connected(data, 1) + + +def mock_head(testcase, expected_generator_inputs, expected_real_data, + generator_scope_name): + """Returns a mock head that validates logits values and variable names.""" + discriminator_scope_name = 'Discriminator' # comes from TFGAN defaults + generator_var_names = set([ + '%s/fully_connected/weights:0' % generator_scope_name, + '%s/fully_connected/biases:0' % generator_scope_name]) + discriminator_var_names = set([ + '%s/fully_connected/weights:0' % discriminator_scope_name, + '%s/fully_connected/biases:0' % discriminator_scope_name]) + + def _create_estimator_spec(features, mode, logits, labels): + gan_model = logits # renaming for clarity + is_predict = mode == model_fn_lib.ModeKeys.PREDICT + testcase.assertIsNone(features) + testcase.assertIsNone(labels) + testcase.assertIsInstance(gan_model, namedtuples.GANModel) + + trainable_vars = ops.get_collection(ops.GraphKeys.TRAINABLE_VARIABLES) + expected_var_names = (generator_var_names if is_predict else + generator_var_names | discriminator_var_names) + testcase.assertItemsEqual(expected_var_names, + [var.name for var in trainable_vars]) + + assertions = [] + def _or_none(x): + return None if is_predict else x + testcase.assertEqual(expected_generator_inputs, gan_model.generator_inputs) + # TODO(joelshor): Add check on `generated_data`. + testcase.assertItemsEqual( + generator_var_names, + set([x.name for x in gan_model.generator_variables])) + testcase.assertEqual(generator_scope_name, gan_model.generator_scope.name) + testcase.assertEqual(generator_fn, gan_model.generator_fn) + testcase.assertEqual(_or_none(expected_real_data), gan_model.real_data) + # TODO(joelshor): Add check on `discriminator_real_outputs`. + # TODO(joelshor): Add check on `discriminator_gen_outputs`. + if is_predict: + testcase.assertIsNone(gan_model.discriminator_scope) + else: + testcase.assertEqual(discriminator_scope_name, + gan_model.discriminator_scope.name) + testcase.assertEqual(_or_none(discriminator_fn), gan_model.discriminator_fn) + + with ops.control_dependencies(assertions): + if mode == model_fn_lib.ModeKeys.TRAIN: + return model_fn_lib.EstimatorSpec( + mode=mode, loss=array_ops.zeros([]), + train_op=control_flow_ops.no_op(), training_hooks=[]) + elif mode == model_fn_lib.ModeKeys.EVAL: + return model_fn_lib.EstimatorSpec( + mode=mode, predictions=gan_model.generated_data, + loss=array_ops.zeros([])) + elif mode == model_fn_lib.ModeKeys.PREDICT: + return model_fn_lib.EstimatorSpec( + mode=mode, predictions=gan_model.generated_data) + else: + testcase.fail('Invalid mode: {}'.format(mode)) + + head = test.mock.NonCallableMagicMock(spec=head_lib._Head) + head.create_estimator_spec = test.mock.MagicMock( + wraps=_create_estimator_spec) + + return head + + +class GANModelFnTest(test.TestCase): + """Tests that _gan_model_fn passes expected logits to mock head.""" + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_logits_helper(self, mode): + """Tests that the expected logits are passed to mock head.""" + with ops.Graph().as_default(): + training_util.get_or_create_global_step() + generator_inputs = {'x': array_ops.zeros([5, 4])} + real_data = (None if mode == model_fn_lib.ModeKeys.PREDICT else + array_ops.zeros([5, 4])) + generator_scope_name = 'generator' + head = mock_head(self, + expected_generator_inputs=generator_inputs, + expected_real_data=real_data, + generator_scope_name=generator_scope_name) + estimator_spec = estimator._gan_model_fn( + features=generator_inputs, + labels=real_data, + mode=mode, + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_scope_name=generator_scope_name, + head=head) + with monitored_session.MonitoredTrainingSession( + checkpoint_dir=self._model_dir) as sess: + if mode == model_fn_lib.ModeKeys.TRAIN: + sess.run(estimator_spec.train_op) + elif mode == model_fn_lib.ModeKeys.EVAL: + sess.run(estimator_spec.loss) + elif mode == model_fn_lib.ModeKeys.PREDICT: + sess.run(estimator_spec.predictions) + else: + self.fail('Invalid mode: {}'.format(mode)) + + def test_logits_predict(self): + self._test_logits_helper(model_fn_lib.ModeKeys.PREDICT) + + def test_logits_eval(self): + self._test_logits_helper(model_fn_lib.ModeKeys.EVAL) + + def test_logits_train(self): + self._test_logits_helper(model_fn_lib.ModeKeys.TRAIN) + + +# TODO(joelshor): Add pandas test. +class GANEstimatorIntegrationTest(test.TestCase): + + def setUp(self): + self._model_dir = tempfile.mkdtemp() + + def tearDown(self): + if self._model_dir: + writer_cache.FileWriterCache.clear() + shutil.rmtree(self._model_dir) + + def _test_complete_flow( + self, train_input_fn, eval_input_fn, predict_input_fn, prediction_size, + lr_decay=False): + def make_opt(): + gstep = training_util.get_or_create_global_step() + lr = learning_rate_decay.exponential_decay(1.0, gstep, 10, 0.9) + return training.GradientDescentOptimizer(lr) + + gopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + dopt = make_opt if lr_decay else training.GradientDescentOptimizer(1.0) + est = estimator.GANEstimator( + generator_fn=generator_fn, + discriminator_fn=discriminator_fn, + generator_loss_fn=losses.wasserstein_generator_loss, + discriminator_loss_fn=losses.wasserstein_discriminator_loss, + generator_optimizer=gopt, + discriminator_optimizer=dopt, + model_dir=self._model_dir) + + # TRAIN + num_steps = 10 + est.train(train_input_fn, steps=num_steps) + + # EVALUTE + scores = est.evaluate(eval_input_fn) + self.assertEqual(num_steps, scores[ops.GraphKeys.GLOBAL_STEP]) + self.assertIn('loss', six.iterkeys(scores)) + + # PREDICT + predictions = np.array([x for x in est.predict(predict_input_fn)]) + + self.assertAllEqual(prediction_size, predictions.shape) + + def test_numpy_input_fn(self): + """Tests complete flow with numpy_input_fn.""" + input_dim = 4 + batch_size = 5 + data = np.zeros([batch_size, input_dim]) + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + prediction_size=[batch_size, input_dim]) + + def test_numpy_input_fn_lrdecay(self): + """Tests complete flow with numpy_input_fn.""" + input_dim = 4 + batch_size = 5 + data = np.zeros([batch_size, input_dim]) + train_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + num_epochs=None, + shuffle=True) + eval_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + y=data, + batch_size=batch_size, + shuffle=False) + predict_input_fn = numpy_io.numpy_input_fn( + x={'x': data}, + batch_size=batch_size, + shuffle=False) + + self._test_complete_flow( + train_input_fn=train_input_fn, + eval_input_fn=eval_input_fn, + predict_input_fn=predict_input_fn, + prediction_size=[batch_size, input_dim], + lr_decay=True) + + def test_input_fn_from_parse_example(self): + """Tests complete flow with input_fn constructed from parse_example.""" + input_dim = 4 + batch_size = 6 + data = np.zeros([batch_size, input_dim]) + + serialized_examples = [] + for datum in data: + example = example_pb2.Example(features=feature_pb2.Features( + feature={ + 'x': feature_pb2.Feature( + float_list=feature_pb2.FloatList(value=datum)), + 'y': feature_pb2.Feature( + float_list=feature_pb2.FloatList(value=datum)), + })) + serialized_examples.append(example.SerializeToString()) + + feature_spec = { + 'x': parsing_ops.FixedLenFeature([input_dim], dtypes.float32), + 'y': parsing_ops.FixedLenFeature([input_dim], dtypes.float32), + } + def _train_input_fn(): + feature_map = parsing_ops.parse_example( + serialized_examples, feature_spec) + _, features = graph_io.queue_parsed_features(feature_map) + labels = features.pop('y') + return features, labels + def _eval_input_fn(): + feature_map = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + _, features = graph_io.queue_parsed_features(feature_map) + labels = features.pop('y') + return features, labels + def _predict_input_fn(): + feature_map = parsing_ops.parse_example( + input_lib.limit_epochs(serialized_examples, num_epochs=1), + feature_spec) + _, features = graph_io.queue_parsed_features(feature_map) + features.pop('y') + return features, None + + self._test_complete_flow( + train_input_fn=_train_input_fn, + eval_input_fn=_eval_input_fn, + predict_input_fn=_predict_input_fn, + prediction_size=[batch_size, input_dim]) + + +if __name__ == '__main__': + test.main() diff --git a/tensorflow/contrib/gan/python/estimator/python/head.py b/tensorflow/contrib/gan/python/estimator/python/head.py new file mode 100644 index 00000000000..3225d6f41a1 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/head.py @@ -0,0 +1,28 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""`tf.Learn` components for `GANEstimator`'s loss.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python.estimator.python import head_impl +# pylint: disable=wildcard-import +from tensorflow.contrib.gan.python.estimator.python.head_impl import * +# pylint: enable=wildcard-import +from tensorflow.python.util.all_util import remove_undocumented + +__all__ = head_impl.__all__ +remove_undocumented(__name__, __all__) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_impl.py b/tensorflow/contrib/gan/python/estimator/python/head_impl.py new file mode 100644 index 00000000000..204c646e194 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/head_impl.py @@ -0,0 +1,206 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""A TFGAN-backed GAN Estimator.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import functools + +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python import train as tfgan_train +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.estimator.canned import head +from tensorflow.python.framework import ops + +__all__ = [ + 'GANHead', + 'gan_head', +] + + +def gan_head(generator_loss_fn, discriminator_loss_fn, generator_optimizer, + discriminator_optimizer, use_loss_summaries=True, + get_hooks_fn=tfgan_train.get_sequential_train_hooks(), + name=None): + """Creates a `GANHead`. + + Args: + generator_loss_fn: A TFGAN loss function for the generator. Takes a + `GANModel` and returns a scalar. + discriminator_loss_fn: Same as `generator_loss_fn`, but for the + discriminator. + generator_optimizer: The optimizer for generator updates. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + use_loss_summaries: If `True`, add loss summaries. If `False`, does not. + If `None`, uses defaults. + get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list + of hooks. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. + + Returns: + An instance of `GANHead`. + """ + return GANHead(generator_loss_fn=generator_loss_fn, + discriminator_loss_fn=discriminator_loss_fn, + generator_optimizer=generator_optimizer, + discriminator_optimizer=discriminator_optimizer, + use_loss_summaries=use_loss_summaries, + get_hooks_fn=get_hooks_fn, + name=name) + + +class GANHead(head._Head): # pylint: disable=protected-access + """`Head` for a GAN.""" + + def __init__(self, generator_loss_fn, discriminator_loss_fn, + generator_optimizer, discriminator_optimizer, + use_loss_summaries=True, + get_hooks_fn=tfgan_train.get_sequential_train_hooks(), + name=None): + """`Head` for GAN training. + + Args: + generator_loss_fn: A TFGAN loss function for the generator. Takes a + `GANModel` and returns a scalar. + discriminator_loss_fn: Same as `generator_loss_fn`, but for the + discriminator. + generator_optimizer: The optimizer for generator updates. + discriminator_optimizer: Same as `generator_optimizer`, but for the + discriminator updates. + use_loss_summaries: If `True`, add loss summaries. If `False`, does not. + If `None`, uses defaults. + get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list + of hooks. + name: name of the head. If provided, summary and metrics keys will be + suffixed by `"/" + name`. + """ + # TODO(joelshor): Validate inputs. + + if use_loss_summaries in [True, False]: + generator_loss_fn = functools.partial( + generator_loss_fn, add_summaries=use_loss_summaries) + discriminator_loss_fn = functools.partial( + discriminator_loss_fn, add_summaries=use_loss_summaries) + self._generator_loss_fn = generator_loss_fn + self._discriminator_loss_fn = discriminator_loss_fn + self._generator_optimizer = generator_optimizer + self._discriminator_optimizer = discriminator_optimizer + self._get_hooks_fn = get_hooks_fn + + @property + def name(self): + return self._name + + @property + def logits_dimension(self): + return None + + def create_loss(self, features, mode, logits, labels): + """Returns a GANLoss tuple from the provided GANModel. + + See `Head` for more details. + + Args: + features: Input `dict` of `Tensor` objects. Unused. + mode: Estimator's `ModeKeys`. + logits: A GANModel tuple. + labels: Must be `None`. + + Returns: + A GANLoss tuple. + + """ + _validate_logits_and_labels(logits, labels) + del mode, labels, features # unused for this head. + gan_model = logits # rename variable for clarity + return tfgan_tuples.GANLoss( + generator_loss=self._generator_loss_fn(gan_model), + discriminator_loss=self._discriminator_loss_fn(gan_model)) + + def create_estimator_spec( + self, features, mode, logits, labels=None, + train_op_fn=tfgan_train.gan_train_ops): + """Returns `EstimatorSpec` that a model_fn can return. + + See `Head` for more details. + + Args: + features: Must be `None`. + mode: Estimator's `ModeKeys`. + logits: A GANModel tuple. + labels: Must be `None`. + train_op_fn: Function that takes a GANModel, GANLoss, generator optimizer, + and discriminator optimizer, and returns a `GANTrainOps` tuple. For + example, this function can come from TFGAN's `train.py` library, or can + be custom. + + Returns: + `EstimatorSpec`. + + Raises: + ValueError: If `features` isn't `None`. + ValueError: If `train_op_fn` isn't provided in train mode. + """ + _validate_logits_and_labels(logits, labels) + if features is not None: + raise ValueError('`features` should be `None`. Instead, found: %s' % + features) + gan_model = logits # rename variable for clarity + with ops.name_scope('GANHead'): + if mode == model_fn_lib.ModeKeys.PREDICT: + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.PREDICT, + predictions=gan_model.generated_data) + elif mode == model_fn_lib.ModeKeys.EVAL: + gan_loss = self.create_loss( + features=None, mode=mode, logits=gan_model, labels=None) + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + return model_fn_lib.EstimatorSpec( + mode=model_fn_lib.ModeKeys.EVAL, + predictions=gan_model.generated_data, + loss=scalar_loss, + # TODO(joelshor): Add metrics. If head name provided, append it to + # metric keys. + eval_metric_ops={}) + elif mode == model_fn_lib.ModeKeys.TRAIN: + if train_op_fn is None: + raise ValueError('train_op_fn can not be None.') + gan_loss = self.create_loss(None, mode, gan_model, None) + scalar_loss = gan_loss.generator_loss + gan_loss.discriminator_loss + train_ops = train_op_fn(gan_model, gan_loss, self._generator_optimizer, + self._discriminator_optimizer) + training_hooks = self._get_hooks_fn(train_ops) + return model_fn_lib.EstimatorSpec( + loss=scalar_loss, + mode=model_fn_lib.ModeKeys.TRAIN, + train_op=train_ops.global_step_inc_op, + training_hooks=training_hooks) + else: + raise ValueError('Mode not recognized: %s' % mode) + + +def _validate_logits_and_labels(logits, labels): + if labels is not None: + raise ValueError('`GANHead`\'s `create_estimator_spec` input `labels` must ' + 'be `None`. Instead, found: %s' % labels) + + if not isinstance(logits, tfgan_tuples.GANModel): + raise ValueError('`GANHead`\'s `create_estimator_spec` input `logits` must ' + 'be an instnace of a `GANModel`. Instead, found: %s' % + logits) diff --git a/tensorflow/contrib/gan/python/estimator/python/head_test.py b/tensorflow/contrib/gan/python/estimator/python/head_test.py new file mode 100644 index 00000000000..8168f005cd1 --- /dev/null +++ b/tensorflow/contrib/gan/python/estimator/python/head_test.py @@ -0,0 +1,85 @@ +# Copyright 2017 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== +"""Tests for TFGAN's head.py.""" + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +from tensorflow.contrib.gan.python import namedtuples as tfgan_tuples +from tensorflow.contrib.gan.python.estimator.python import head + +from tensorflow.python.estimator import model_fn as model_fn_lib +from tensorflow.python.ops import array_ops +from tensorflow.python.ops import math_ops +from tensorflow.python.ops import variable_scope +from tensorflow.python.platform import test +from tensorflow.python.training import training + + +def dummy_loss(gan_model, add_summaries=True): # pylint:disable=unused-argument + return math_ops.reduce_sum(gan_model.discriminator_real_outputs - + gan_model.discriminator_gen_outputs) + + +def get_gan_model(): + # TODO(joelshor): Find a better way of creating a variable scope. + with variable_scope.variable_scope('generator') as gen_scope: + gen_var = variable_scope.get_variable('dummy_var', initializer=0.0) + with variable_scope.variable_scope('discriminator') as dis_scope: + dis_var = variable_scope.get_variable('dummy_var', initializer=0.0) + return tfgan_tuples.GANModel( + generator_inputs=None, + generated_data=array_ops.ones([3, 4]), + generator_variables=[gen_var], + generator_scope=gen_scope, + generator_fn=None, + real_data=None, + discriminator_real_outputs=array_ops.ones([1, 2, 3]) * dis_var, + discriminator_gen_outputs=array_ops.ones([1, 2, 3]) * gen_var * dis_var, + discriminator_variables=[dis_var], + discriminator_scope=dis_scope, + discriminator_fn=None) + + +class GANHeadTest(test.TestCase): + + def setUp(self): + super(GANHeadTest, self).setUp() + self.gan_head = head.gan_head( + generator_loss_fn=dummy_loss, + discriminator_loss_fn=dummy_loss, + generator_optimizer=training.GradientDescentOptimizer(1.0), + discriminator_optimizer=training.GradientDescentOptimizer(1.0)) + self.assertTrue(isinstance(self.gan_head, head.GANHead)) + + def _test_modes_helper(self, mode): + self.gan_head.create_estimator_spec( + features=None, + mode=mode, + logits=get_gan_model()) + + def test_modes_predict(self): + self._test_modes_helper(model_fn_lib.ModeKeys.PREDICT) + + def test_modes_eval(self): + self._test_modes_helper(model_fn_lib.ModeKeys.EVAL) + + def test_modes_train(self): + self._test_modes_helper(model_fn_lib.ModeKeys.TRAIN) + + +if __name__ == '__main__': + test.main()