From 47b04fdb3ee2491dfb0306ea56617d0e95b38c27 Mon Sep 17 00:00:00 2001
From: Rohan Jain
Date: Sat, 13 Oct 2018 17:42:53 -0700
Subject: [PATCH] Rolling forward changing distribution strategies to use
 MultiDeviceIterator.

The underlying issue with NaNs has now been resolved.

PiperOrigin-RevId: 217014692
---
 tensorflow/contrib/distribute/python/BUILD        |  28 +--
 .../distribute/python/metrics_v1_test.py          |   3 +-
 .../distribute/python/minimize_loss_test.py       |  26 +-
 .../python/mirrored_strategy_multigpu_test.py     |  12 +-
 .../contrib/distribute/python/monitor.py          |   1 +
 .../distribute/python/optimizer_v2_test.py        |   8 +-
 .../distribute/python/prefetching_ops_v2.py       | 232 ------------------
 .../python/prefetching_ops_v2_test.py             |  90 -------
 .../contrib/distribute/python/step_fn.py          |   7 +-
 .../contrib/distribute/python/step_fn_test.py     |   1 +
 .../contrib/distribute/python/values.py           |  40 ++-
 .../contrib/distribute/python/values_test.py      |  23 +-
 .../data/ops/multi_device_iterator_ops.py         |  12 +
 13 files changed, 90 insertions(+), 393 deletions(-)
 delete mode 100644 tensorflow/contrib/distribute/python/prefetching_ops_v2.py
 delete mode 100644 tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py

diff --git a/tensorflow/contrib/distribute/python/BUILD b/tensorflow/contrib/distribute/python/BUILD
index 76d5b59ce17..dc2964568b5 100644
--- a/tensorflow/contrib/distribute/python/BUILD
+++ b/tensorflow/contrib/distribute/python/BUILD
@@ -22,7 +22,6 @@ py_library(
     visibility = ["//tensorflow:internal"],
     deps = [
         ":input_ops",
-        ":prefetching_ops_v2",
         "//tensorflow/python:array_ops",
        "//tensorflow/python:control_flow_ops",
         "//tensorflow/python:device_util",
@@ -31,6 +30,7 @@ py_library(
         "//tensorflow/python:resource_variable_ops",
         "//tensorflow/python:training",
         "//tensorflow/python:util",
+        "//tensorflow/python/data/ops:multi_device_iterator_ops",
         "//tensorflow/python/eager:context",
         "//tensorflow/python/training/checkpointable:base",
         "@six_archive//:six",
@@ -666,32 +666,6 @@ cuda_py_test(
     ],
 )
 
-py_library(
-    name = "prefetching_ops_v2",
-    srcs = ["prefetching_ops_v2.py"],
-    deps = [
-        "//tensorflow/contrib/data/python/ops:prefetching_ops",
-        "//tensorflow/python:experimental_dataset_ops_gen",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/data/util:nest",
-        "//tensorflow/python/data/util:sparse",
-    ],
-)
-
-cuda_py_test(
-    name = "prefetching_ops_v2_test",
-    srcs = ["prefetching_ops_v2_test.py"],
-    additional_deps = [
-        ":prefetching_ops_v2",
-        "//tensorflow/python:client_testlib",
-        "//tensorflow/python:framework_ops",
-        "//tensorflow/python:framework_test_lib",
-        "//tensorflow/python/data/ops:dataset_ops",
-        "//tensorflow/python/data/ops:iterator_ops",
-    ],
-)
-
 py_library(
     name = "input_ops",
     srcs = ["input_ops.py"],
diff --git a/tensorflow/contrib/distribute/python/metrics_v1_test.py b/tensorflow/contrib/distribute/python/metrics_v1_test.py
index ae4189eb1cb..2c79a8bfd3c 100644
--- a/tensorflow/contrib/distribute/python/metrics_v1_test.py
+++ b/tensorflow/contrib/distribute/python/metrics_v1_test.py
@@ -96,7 +96,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
   def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn):
     with ops.Graph().as_default(), distribution.scope():
       iterator = distribution.distribute_dataset(
-          dataset_fn).make_one_shot_iterator()
+          dataset_fn).make_initializable_iterator()
       if isinstance(distribution, tpu_strategy.TPUStrategy):
         def step_fn(ctx, inputs):
           value, update = distribution.call_for_each_tower(
@@ -120,6 +120,7 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase):
       # replace "distribution.num_towers" with "1".
       batches_per_update = distribution.num_towers
 
+      self.evaluate(iterator.initializer)
       self.evaluate(distribution.initialize())
       self.evaluate(variables.local_variables_initializer())
 
diff --git a/tensorflow/contrib/distribute/python/minimize_loss_test.py b/tensorflow/contrib/distribute/python/minimize_loss_test.py
index 60e134055ff..3c4544a39ef 100644
--- a/tensorflow/contrib/distribute/python/minimize_loss_test.py
+++ b/tensorflow/contrib/distribute/python/minimize_loss_test.py
@@ -41,6 +41,14 @@ from tensorflow.python.ops.losses import losses_impl
 
 class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
 
+  def _get_iterator(self, ds):
+    if context.executing_eagerly():
+      iterator = ds.make_one_shot_iterator()
+    else:
+      iterator = ds.make_initializable_iterator()
+      self.evaluate(iterator.initializer)
+    return iterator
+
   @combinations.generate(
       combinations.times(
           combinations.distributions_and_v1_optimizers(),
@@ -62,8 +70,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
           distribution.call_for_each_tower(
               model_fn, *inputs, run_concurrently=layer.built))
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
     def run_step():
       return distribution.run_steps_on_dataset(
@@ -99,8 +106,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
     model_fn, dataset_fn, layer = minimize_loss_example(
         optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
     def run_step():
       return distribution.group(
@@ -159,8 +165,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
           distribution.call_for_each_tower(
               model_fn, *inputs, run_concurrently=layer.built))
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
     def run_step():
       return distribution.run_steps_on_dataset(
@@ -239,8 +244,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
         fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS)
       return control_flow_ops.group(fetches)
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
    def run_step():
       return distribution.run_steps_on_dataset(
@@ -333,8 +337,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
           distribution.call_for_each_tower(
               model_fn, x, y, run_concurrently=False))
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
     def run_step():
       return distribution.run_steps_on_dataset(
@@ -427,8 +430,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase):
           output=loss)
       return distribution.group(train_op)
 
-    iterator = distribution.distribute_dataset(
-        dataset_fn).make_one_shot_iterator()
+    iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn))
 
     def run_step():
       initial_loss = lambda: constant_op.constant(1e7)
diff --git a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
index ed36639ce86..fd833c772d4 100644
--- a/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
+++ b/tensorflow/contrib/distribute/python/mirrored_strategy_multigpu_test.py
@@ -307,9 +307,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase):
     dist = mirrored_strategy.MirroredStrategy(
         ["/device:GPU:0", "/device:CPU:0"])
 
-    features = dist.distribute_dataset(
-        lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)
-    ).make_one_shot_iterator().get_next()
+    ds = dist.distribute_dataset(
+        lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10))
+    if context.executing_eagerly():
+      iterator = ds.make_one_shot_iterator()
+    else:
+      iterator = ds.make_initializable_iterator()
+      self.evaluate([iterator.initializer])
+
+    features = iterator.get_next()
 
     with dist.scope():
       result = dist.call_for_each_tower(
diff --git a/tensorflow/contrib/distribute/python/monitor.py b/tensorflow/contrib/distribute/python/monitor.py
index 7644acedc99..17b7ab74f63 100644
--- a/tensorflow/contrib/distribute/python/monitor.py
+++ b/tensorflow/contrib/distribute/python/monitor.py
@@ -51,6 +51,7 @@ class Monitor(object):
     else:
       if session is None:
         raise ValueError("Should provide a `session` in Graph mode.")
+      session.run(step_callable._iterator.initializer)  # pylint: disable=protected-access
       self._run_step = session.make_callable(step_callable())
       session.run(variables.global_variables_initializer())
 
diff --git a/tensorflow/contrib/distribute/python/optimizer_v2_test.py b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
index 6e9ba37a198..30644331298 100644
--- a/tensorflow/contrib/distribute/python/optimizer_v2_test.py
+++ b/tensorflow/contrib/distribute/python/optimizer_v2_test.py
@@ -42,8 +42,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
       model_fn, dataset_fn, layer = minimize_loss_example(
           optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss)
 
-      iterator = distribution.distribute_dataset(
-          dataset_fn).make_one_shot_iterator()
+      ds = distribution.distribute_dataset(dataset_fn)
+      if context.executing_eagerly():
+        iterator = ds.make_one_shot_iterator()
+      else:
+        iterator = ds.make_initializable_iterator()
 
       def run_step():
         return control_flow_ops.group(distribution.unwrap(
@@ -52,6 +55,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase):
 
       if not context.executing_eagerly():
         with self.cached_session() as sess:
+          sess.run(iterator.initializer)
           run_step = sess.make_callable(run_step())
 
       self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
deleted file mode 100644
index d48aa9c89bc..00000000000
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2.py
+++ /dev/null
@@ -1,232 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Extension of prefetching_ops to support more than one device."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-import warnings
-
-from tensorflow.python.data.experimental.ops import prefetching_ops
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.data.ops import iterator_ops
-from tensorflow.python.data.util import nest as data_nest
-from tensorflow.python.data.util import sparse
-from tensorflow.python.eager import context
-from tensorflow.python.framework import dtypes
-from tensorflow.python.framework import function
-from tensorflow.python.framework import ops
-from tensorflow.python.ops import gen_experimental_dataset_ops as ged_ops
-from tensorflow.python.util import nest
-
-
-# pylint: disable=protected-access
-class _PrefetchToDeviceIterator(object):
-  """A replacement for `tf.data.Iterator` that prefetches to another device.
-
-  Args:
-    input_dataset: The input dataset.
-    one_shot: If true, we make a one shot iterator that's already initialized.
-    devices: Devices on which to prefetch.
-    buffer_size: Size of the prefetching buffer.
-    shared_name: (Optional.) If non-empty, the returned iterator will be shared
-      under the given name across multiple sessions that share the same devices
-      (e.g. when using a remote server). Only used if one_shot is False.
-
-  Returns:
-    An Iterator type object.
-  """
-
-  def __init__(self,
-               input_dataset,
-               one_shot,
-               devices,
-               buffer_size,
-               shared_name=None):
-    self._input_dataset = input_dataset
-    self._get_next_call_count = 0
-    self._one_shot = one_shot
-    if shared_name is None:
-      shared_name = ""
-    self._devices = devices
-
-    if self._one_shot:
-      self._input_iterator = input_dataset.make_one_shot_iterator()
-    else:
-      self._input_iterator = iterator_ops.Iterator.from_structure(
-          self._input_dataset.output_types, self._input_dataset.output_shapes,
-          shared_name, self._input_dataset.output_classes)
-    input_iterator_handle = self._input_iterator.string_handle()
-
-    @function.Defun(dtypes.string)
-    def _prefetch_fn(handle):
-      """Prefetches one element from `input_iterator`."""
-      remote_iterator = iterator_ops.Iterator.from_string_handle(
-          handle, self._input_iterator.output_types,
-          self._input_iterator.output_shapes,
-          self._input_iterator.output_classes)
-      ret = remote_iterator.get_next()
-      return nest.flatten(sparse.serialize_sparse_tensors(ret))
-
-    target_device = ged_ops.experimental_iterator_get_device(
-        self._input_iterator._iterator_resource)
-    self._buffering_resources = []
-    for device in nest.flatten(self._devices):
-      with ops.device(device):
-        buffer_resource_handle = prefetching_ops.function_buffering_resource(
-            f=_prefetch_fn,
-            output_types=data_nest.flatten(
-                sparse.as_dense_types(self._input_dataset.output_types,
-                                      self._input_dataset.output_classes)),
-            target_device=target_device,
-            string_arg=input_iterator_handle,
-            buffer_size=buffer_size,
-            shared_name=shared_name)
-        self._buffering_resources.append(buffer_resource_handle)
-
-    if not self._one_shot:
-      reset_ops = []
-      for buffer_resource in self._buffering_resources:
-        reset_ops.append(
-            ged_ops.experimental_function_buffering_resource_reset(
-                buffer_resource))
-      with ops.control_dependencies(reset_ops):
-        self._initializer = self._input_iterator.make_initializer(
-            self._input_dataset)
-
-  def get_next(self, name=None):
-    """See `tf.data.Iterator.get_next`."""
-    self._get_next_call_count += 1
-    if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD:
-      warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE)
-
-    flat_result = []
-    # TODO(priyag): This will fail if the input size (typically number of
-    # batches) is not divisible by number of devices.
-    # How do we handle that more gracefully / let the user know?
-    for buffer_resource in self._buffering_resources:
-      flat_ret = ged_ops.experimental_function_buffering_resource_get_next(
-          buffer_resource,
-          output_types=data_nest.flatten(
-              sparse.as_dense_types(self.output_types, self.output_classes)),
-          name=name)
-
-      ret = sparse.deserialize_sparse_tensors(
-          data_nest.pack_sequence_as(self.output_types, flat_ret),
-          self.output_types, self.output_shapes, self.output_classes)
-
-      for tensor, shape in zip(
-          data_nest.flatten(ret), data_nest.flatten(self.output_shapes)):
-        if isinstance(tensor, ops.Tensor):
-          tensor.set_shape(shape)
-      flat_result.append(ret)
-
-    return nest.pack_sequence_as(self._devices, flat_result)
-
-  @property
-  def initializer(self):
-    if self._one_shot:
-      raise NotImplementedError("Can't initialize a one_shot_iterator")
-    return self._initializer
-
-  @property
-  def output_classes(self):
-    return self._input_dataset.output_classes
-
-  @property
-  def output_shapes(self):
-    return self._input_dataset.output_shapes
-
-  @property
-  def output_types(self):
-    return self._input_dataset.output_types
-
-
-# pylint: enable=protected-access
-
-
-class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset):
-  """A `Dataset` whose iterator prefetches elements to other device(s)."""
-
-  def __init__(self, input_dataset, devices, buffer_size):
-    super(_PrefetchToDeviceDataset, self).__init__(input_dataset)
-    self._input_dataset = input_dataset
-    self._devices = devices
-    self._buffer_size = buffer_size if buffer_size is not None else 1
-
-  def make_one_shot_iterator(self):
-    return _PrefetchToDeviceIterator(
-        self._input_dataset,
-        one_shot=True,
-        devices=self._devices,
-        buffer_size=self._buffer_size)
-
-  def make_initializable_iterator(self, shared_name=None):
-    if context.executing_eagerly():
-      raise RuntimeError(
-          "make_initializable_iterator is not supported when eager "
-          "execution is enabled.")
-
-    return _PrefetchToDeviceIterator(
-        self._input_dataset,
-        one_shot=False,
-        devices=self._devices,
-        buffer_size=self._buffer_size,
-        shared_name=shared_name)
-
-  def _as_variant_tensor(self):
-    # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset
-    # transformation methods is called.
-    # TODO(mrry): Investigate support for chaining further transformations after
-    # the prefetch, including GPU support.
-    raise NotImplementedError("`prefetch_to_devices()` must be the last "
-                              "transformation in a dataset pipeline.")
-
-  # TODO(priyag): Fix the output types, shapes and classes to match the result
-  # of get_next (which has the additional nesting layer of devices now).
-  @property
-  def output_types(self):
-    return self._input_dataset.output_types
-
-  @property
-  def output_shapes(self):
-    return self._input_dataset.output_shapes
-
-  @property
-  def output_classes(self):
-    return self._input_dataset.output_classes
-
-
-def prefetch_to_devices(devices, buffer_size=None):
-  """A transformation that prefetches dataset values to the given `devices`.
-
-  NOTE: Although the transformation creates a `tf.data.Dataset`, the
-  transformation must be the final `Dataset` in the input pipeline.
-
-  Args:
-    devices: A nested structure of devices on which to prefetch the data. It can
-      be a single device name, or a tuple or list of device names.
-    buffer_size: (Optional.) The number of elements to buffer on each device.
-      Defaults to an automatically chosen value.
-
-  Returns:
-    A `Dataset` transformation function, which can be passed to
-    `tf.data.Dataset.apply`.
-  """
-
-  def _apply_fn(dataset):
-    return _PrefetchToDeviceDataset(dataset, devices, buffer_size)
-
-  return _apply_fn
diff --git a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py b/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
deleted file mode 100644
index 16799104e81..00000000000
--- a/tensorflow/contrib/distribute/python/prefetching_ops_v2_test.py
+++ /dev/null
@@ -1,90 +0,0 @@
-# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-# ==============================================================================
-"""Tests for prefetching_ops_v2."""
-from __future__ import absolute_import
-from __future__ import division
-from __future__ import print_function
-
-from tensorflow.contrib.distribute.python import prefetching_ops_v2
-from tensorflow.python.data.ops import dataset_ops
-from tensorflow.python.framework import errors
-from tensorflow.python.framework import test_util
-from tensorflow.python.platform import test
-
-
-class PrefetchingOpsV2Test(test.TestCase):
-
-  def testPrefetchToOneDevice(self):
-    if not test_util.is_gpu_available():
-      self.skipTest("No GPU available")
-
-    host_dataset = dataset_ops.Dataset.range(10)
-    device_dataset = host_dataset.apply(
-        prefetching_ops_v2.prefetch_to_devices("/gpu:0"))
-
-    iterator = device_dataset.make_one_shot_iterator()
-    next_element = iterator.get_next()
-
-    with self.cached_session() as sess:
-      for i in range(10):
-        self.assertEqual(i, sess.run(next_element))
-      with self.assertRaises(errors.OutOfRangeError):
-        sess.run(next_element)
-
-  def testPrefetchToTwoDevicesInAList(self):
-    if not test_util.is_gpu_available():
-      self.skipTest("No GPU available")
-
-    host_dataset = dataset_ops.Dataset.range(10)
-    device_dataset = host_dataset.apply(
-        prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
-
-    iterator = device_dataset.make_one_shot_iterator()
-    next_element = iterator.get_next()
-
-    output = []
-    # TODO(rohanj): Modify test to go till the end of the dataset when we
-    # switch to MultiDeviceIterator.
-    with self.cached_session() as sess:
-      for _ in range(4):
-        result = sess.run(next_element)
-        self.assertEqual(2, len(result))
-        output.extend(result)
-      self.assertEquals(set(range(8)), set(output))
-
-  def testPrefetchToTwoDevicesWithReinit(self):
-    if not test_util.is_gpu_available():
-      self.skipTest("No GPU available")
-
-    host_dataset = dataset_ops.Dataset.range(10)
-    device_dataset = host_dataset.apply(
-        prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"]))
-
-    iterator = device_dataset.make_initializable_iterator()
-    next_element = iterator.get_next()
-
-    # TODO(rohanj): Modify test to go till the end of the dataset when we
-    # switch to MultiDeviceIterator.
-    with self.cached_session() as sess:
-      sess.run(iterator.initializer)
-      for _ in range(4):
-        sess.run(next_element)
-      sess.run(iterator.initializer)
-      for _ in range(4):
-        sess.run(next_element)
-
-
-if __name__ == "__main__":
-  test.main()
diff --git a/tensorflow/contrib/distribute/python/step_fn.py b/tensorflow/contrib/distribute/python/step_fn.py
index 1b5a4f64e5b..23bf36184fa 100644
--- a/tensorflow/contrib/distribute/python/step_fn.py
+++ b/tensorflow/contrib/distribute/python/step_fn.py
@@ -19,6 +19,7 @@ from __future__ import division
 from __future__ import print_function
 
 from tensorflow.python.eager import backprop
+from tensorflow.python.eager import context
 from tensorflow.python.training import optimizer as optimizer_lib
 
 
@@ -50,7 +51,11 @@ class StandardInputStep(Step):
   def __init__(self, dataset_fn, distribution):
     super(StandardInputStep, self).__init__(distribution)
     self._distributed_input = distribution.distribute_dataset(dataset_fn)
-    self._iterator = self._distributed_input.make_one_shot_iterator()
+    if context.executing_eagerly():
+      self._iterator = self._distributed_input.make_one_shot_iterator()
+    else:
+      # TODO(priyag): Expose initializer via some initializer property.
+      self._iterator = self._distributed_input.make_initializable_iterator()
 
 
 class StandardSingleLossStep(StandardInputStep):
diff --git a/tensorflow/contrib/distribute/python/step_fn_test.py b/tensorflow/contrib/distribute/python/step_fn_test.py
index f1ada49fa37..1ff9b9ceec1 100644
--- a/tensorflow/contrib/distribute/python/step_fn_test.py
+++ b/tensorflow/contrib/distribute/python/step_fn_test.py
@@ -50,6 +50,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase):
       run_step = single_loss_step
     else:
       with self.cached_session() as sess:
+        sess.run(single_loss_step._iterator.initializer)
         run_step = sess.make_callable(single_loss_step())
 
     self.evaluate(variables.global_variables_initializer())
diff --git a/tensorflow/contrib/distribute/python/values.py b/tensorflow/contrib/distribute/python/values.py
index 472cb4230c5..c555dc8a71d 100644
--- a/tensorflow/contrib/distribute/python/values.py
+++ b/tensorflow/contrib/distribute/python/values.py
@@ -27,7 +27,7 @@ import weakref
 import six
 
 from tensorflow.contrib.distribute.python import input_ops
-from tensorflow.contrib.distribute.python import prefetching_ops_v2
+from tensorflow.python.data.ops import multi_device_iterator_ops
 from tensorflow.python.eager import context
 from tensorflow.python.eager import tape
 from tensorflow.python.framework import device as tf_device
@@ -1089,7 +1089,7 @@ class PerDeviceDataIterator(object):
   def get_next(self, name=None):
     """Scatter the input across devices."""
     if self._prefetch_on_device:
-      data_list = self._iterator.get_next(name=name)
+      data_list = self._iterator.get_next()
       index = dict(zip(self._devices, data_list))
     else:
       batch = self._iterator.get_next(name=name)
@@ -1113,17 +1113,15 @@ class PerDeviceDataset(object):
     self._devices = devices
 
     # Default to using prefetching in graph mode, unless specified.
-    # TODO(priyag): Enable prefetching in eager mode.
+    # TODO(rohanj): Enable prefetching in eager mode.
     self._prefetch_on_device = prefetch_on_device
     if self._prefetch_on_device is None:
       self._prefetch_on_device = not context.executing_eagerly()
     assert not (self._prefetch_on_device and context.executing_eagerly()), (
         "Prefetching is only supported in graph mode currently")
 
-    if self._prefetch_on_device:
-      self._dataset = dataset.apply(
-          prefetching_ops_v2.prefetch_to_devices(self._devices))
-    else:
+    self._dataset = dataset
+    if not self._prefetch_on_device:
       # TODO(priyag): If dropping remainder is not appropriate, find another
       # approach to distributing the dataset when not possible to divide evenly.
       # Possibly not an issue when we start using PartitionedDataset.
@@ -1131,15 +1129,33 @@
 
   def make_one_shot_iterator(self):
     """Get a one time use iterator for the distributed PerDeviceDataset."""
+    # Graph mode with one shot iterator is disabled.
+    if not context.executing_eagerly():
+      raise ValueError("Cannot create a one shot iterator. Please use "
+                       "`make_initializable_iterator()` instead.")
+    # Eager mode prefetching would error out in constructor. Only remaining
+    # case is non-prefetching in eager mode. We delegate to
+    # PerDeviceDataIterator to handle that case.
     dataset_iterator = self._dataset.make_one_shot_iterator()
-    return PerDeviceDataIterator(dataset_iterator, self._devices,
-                                 self._prefetch_on_device)
+    return PerDeviceDataIterator(
+        dataset_iterator, self._devices, prefetch_on_device=False)
 
   def make_initializable_iterator(self):
     """Get an initializable iterator for the distributed PerDeviceDataset."""
-    dataset_iterator = self._dataset.make_initializable_iterator()
-    return PerDeviceDataIterator(dataset_iterator, self._devices,
-                                 self._prefetch_on_device)
+    # Eager mode generates already initialized iterators. Hence we cannot create
+    # an initializable iterator.
+    if context.executing_eagerly():
+      raise ValueError("Cannot create initializable iterator in Eager mode. "
+                       "Please use `make_one_shot_iterator` instead.")
+    if self._prefetch_on_device:
+      dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+          self._dataset, self._devices)
+    else:
+      dataset_iterator = self._dataset.make_initializable_iterator()
+    return PerDeviceDataIterator(
+        dataset_iterator,
+        self._devices,
+        prefetch_on_device=self._prefetch_on_device)
 
 
 class MultiWorkerDataIterator(object):
diff --git a/tensorflow/contrib/distribute/python/values_test.py b/tensorflow/contrib/distribute/python/values_test.py
index 121d2fbb3fb..7ef4776ac6d 100644
--- a/tensorflow/contrib/distribute/python/values_test.py
+++ b/tensorflow/contrib/distribute/python/values_test.py
@@ -349,7 +349,11 @@ class PerDeviceDatasetTest(test.TestCase):
   def _test_iterator_no_prefetch(self, devices, dataset, expected_values):
     per_device_dataset = values.PerDeviceDataset(
         dataset, devices, prefetch_on_device=False)
-    iterator = per_device_dataset.make_one_shot_iterator()
+    if context.executing_eagerly():
+      iterator = per_device_dataset.make_one_shot_iterator()
+    else:
+      iterator = per_device_dataset.make_initializable_iterator()
+      self.evaluate([iterator.initializer])
 
     for expected_value in expected_values:
       next_element = iterator.get_next()
@@ -366,21 +370,14 @@ class PerDeviceDatasetTest(test.TestCase):
     if not context.executing_eagerly():
       per_device_dataset = values.PerDeviceDataset(
           dataset, devices, prefetch_on_device=True)
-      iterator = per_device_dataset.make_one_shot_iterator()
+      iterator = per_device_dataset.make_initializable_iterator()
+      self.evaluate([iterator.initializer])
 
-      # With prefetching, we cannot guarantee which input ends up on which
-      # device, so we verify that the complete set seen on all devices is
-      # correct, and equal numbers are distributed to each device.
-      combined_actual = []
-      combined_expected = []
       for expected_value in expected_values:
         next_element = iterator.get_next()
-        combined_actual.extend(
-            self.evaluate(
-                [values.select_device(d, next_element) for d in devices]))
-        combined_expected.extend(expected_value)
-
-      self.assertEqual(set(combined_expected), set(combined_actual))
+        computed_value = self.evaluate(
+            [values.select_device(d, next_element) for d in devices])
+        self.assertEqual(expected_value, computed_value)
 
       with self.assertRaises(errors.OutOfRangeError):
         next_element = iterator.get_next()
diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py
index 2086614b7ce..b7033cc4ceb 100644
--- a/tensorflow/python/data/ops/multi_device_iterator_ops.py
+++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py
@@ -229,3 +229,15 @@ class MultiDeviceIterator(object):
   @property
   def initializer(self):
     return self._initializer
+
+  @property
+  def output_types(self):
+    return self._dataset.output_types
+
+  @property
+  def output_shapes(self):
+    return self._dataset.output_shapes
+
+  @property
+  def output_classes(self):
+    return self._dataset.output_classes