Switch Distribution strategies to use MultiDeviceIterator. This is currently only supported in Graph mode using initializable iterators; support for Eager mode will be added in a subsequent change.
This change also removes the prefetching_ops_v2 code. PiperOrigin-RevId: 214546754
This commit is contained in: parent 3f4b8c1381 · commit 7f1d70d97f
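For reference, the iterator-creation pattern the updated tests converge on looks roughly like the sketch below. It is illustrative only: `distribution` stands for any DistributionStrategy instance, `dataset_fn` for a callable returning a `tf.data.Dataset`, and the optional `session` argument is an assumption used to run the initializer in graph mode; it is not part of this commit.

from tensorflow.python.eager import context


def get_distributed_iterator(distribution, dataset_fn, session=None):
  """Sketch: build an iterator over a distributed dataset (not the real API)."""
  ds = distribution.distribute_dataset(dataset_fn)
  if context.executing_eagerly():
    # Eager iterators are created already initialized.
    return ds.make_one_shot_iterator()
  # In graph mode, prefetching now goes through MultiDeviceIterator, which is
  # only reachable via an initializable iterator, so run its initializer.
  iterator = ds.make_initializable_iterator()
  if session is not None:
    session.run(iterator.initializer)
  return iterator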
				| @ -22,7 +22,6 @@ py_library( | ||||
|     visibility = ["//tensorflow:internal"], | ||||
|     deps = [ | ||||
|         ":input_ops", | ||||
|         ":prefetching_ops_v2", | ||||
|         "//tensorflow/python:array_ops", | ||||
|         "//tensorflow/python:control_flow_ops", | ||||
|         "//tensorflow/python:device_util", | ||||
| @ -30,6 +29,7 @@ py_library( | ||||
|         "//tensorflow/python:framework_ops", | ||||
|         "//tensorflow/python:training", | ||||
|         "//tensorflow/python:util", | ||||
|         "//tensorflow/python/data/ops:multi_device_iterator_ops", | ||||
|         "//tensorflow/python/eager:context", | ||||
|         "//tensorflow/python/training/checkpointable:base", | ||||
|         "@six_archive//:six", | ||||
| @ -647,32 +647,6 @@ cuda_py_test( | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_library( | ||||
|     name = "prefetching_ops_v2", | ||||
|     srcs = ["prefetching_ops_v2.py"], | ||||
|     deps = [ | ||||
|         "//tensorflow/contrib/data/python/ops:contrib_op_loader", | ||||
|         "//tensorflow/contrib/data/python/ops:prefetching_ops", | ||||
|         "//tensorflow/python:framework_ops", | ||||
|         "//tensorflow/python/data/ops:dataset_ops", | ||||
|         "//tensorflow/python/data/util:nest", | ||||
|         "//tensorflow/python/data/util:sparse", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| cuda_py_test( | ||||
|     name = "prefetching_ops_v2_test", | ||||
|     srcs = ["prefetching_ops_v2_test.py"], | ||||
|     additional_deps = [ | ||||
|         ":prefetching_ops_v2", | ||||
|         "//tensorflow/python:client_testlib", | ||||
|         "//tensorflow/python:framework_ops", | ||||
|         "//tensorflow/python:framework_test_lib", | ||||
|         "//tensorflow/python/data/ops:dataset_ops", | ||||
|         "//tensorflow/python/data/ops:iterator_ops", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| py_library( | ||||
|     name = "input_ops", | ||||
|     srcs = ["input_ops.py"], | ||||
|  | ||||
| @ -86,10 +86,11 @@ class MetricsV1Test(test.TestCase, parameterized.TestCase): | ||||
|   def _test_metric(self, distribution, dataset_fn, metric_fn, expected_fn): | ||||
|     with ops.Graph().as_default(), distribution.scope(): | ||||
|       iterator = distribution.distribute_dataset( | ||||
|           dataset_fn).make_one_shot_iterator() | ||||
|           dataset_fn).make_initializable_iterator() | ||||
|       value, update = distribution.call_for_each_tower( | ||||
|           metric_fn, iterator.get_next()) | ||||
|       update = distribution.group(update) | ||||
|       self.evaluate(iterator.initializer) | ||||
|       self.evaluate(variables.local_variables_initializer()) | ||||
|       # TODO(josh11b): Once we switch to using a global batch size for input, | ||||
|       # replace "distribution.num_towers" with "1". | ||||
|  | ||||
| @ -41,6 +41,14 @@ from tensorflow.python.ops.losses import losses_impl | ||||
| 
 | ||||
| class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): | ||||
| 
 | ||||
|   def _get_iterator(self, ds): | ||||
|     if context.executing_eagerly(): | ||||
|       iterator = ds.make_one_shot_iterator() | ||||
|     else: | ||||
|       iterator = ds.make_initializable_iterator() | ||||
|       self.evaluate(iterator.initializer) | ||||
|     return iterator | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.times( | ||||
|           combinations.distributions_and_v1_optimizers(), | ||||
| @ -62,8 +70,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): | ||||
|             distribution.call_for_each_tower( | ||||
|                 model_fn, *inputs, run_concurrently=layer.built)) | ||||
| 
 | ||||
|       iterator = distribution.distribute_dataset( | ||||
|           dataset_fn).make_one_shot_iterator() | ||||
|       iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) | ||||
| 
 | ||||
|       def run_step(): | ||||
|         return distribution.run_steps_on_dataset( | ||||
| @ -99,8 +106,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): | ||||
|       model_fn, dataset_fn, layer = minimize_loss_example( | ||||
|           optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) | ||||
| 
 | ||||
|       iterator = distribution.distribute_dataset( | ||||
|           dataset_fn).make_one_shot_iterator() | ||||
|       iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) | ||||
| 
 | ||||
|       def run_step(): | ||||
|         return distribution.group( | ||||
| @ -159,8 +165,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): | ||||
|             distribution.call_for_each_tower( | ||||
|                 model_fn, *inputs, run_concurrently=layer.built)) | ||||
| 
 | ||||
|       iterator = distribution.distribute_dataset( | ||||
|           dataset_fn).make_one_shot_iterator() | ||||
|       iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) | ||||
| 
 | ||||
|       def run_step(): | ||||
|         return distribution.run_steps_on_dataset( | ||||
| @ -244,8 +249,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): | ||||
|           fetches += ops.get_collection(ops.GraphKeys.UPDATE_OPS) | ||||
|         return control_flow_ops.group(fetches) | ||||
| 
 | ||||
|       iterator = distribution.distribute_dataset( | ||||
|           dataset_fn).make_one_shot_iterator() | ||||
|       iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) | ||||
| 
 | ||||
|       def run_step(): | ||||
|         return distribution.run_steps_on_dataset( | ||||
| @ -338,8 +342,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): | ||||
|             distribution.call_for_each_tower( | ||||
|                 model_fn, x, y, run_concurrently=False)) | ||||
| 
 | ||||
|       iterator = distribution.distribute_dataset( | ||||
|           dataset_fn).make_one_shot_iterator() | ||||
|       iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) | ||||
| 
 | ||||
|       def run_step(): | ||||
|         return distribution.run_steps_on_dataset( | ||||
| @ -432,8 +435,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): | ||||
|             output=loss) | ||||
|         return distribution.group(train_op) | ||||
| 
 | ||||
|       iterator = distribution.distribute_dataset( | ||||
|           dataset_fn).make_one_shot_iterator() | ||||
|       iterator = self._get_iterator(distribution.distribute_dataset(dataset_fn)) | ||||
| 
 | ||||
|       def run_step(): | ||||
|         initial_loss = lambda: constant_op.constant(1e7) | ||||
|  | ||||
| @ -480,8 +480,10 @@ class MirroredStrategy(distribute_lib.DistributionStrategy): | ||||
|           self._prefetch_on_device) | ||||
|     else: | ||||
|       return values.PerDeviceDataset( | ||||
|           self._call_dataset_fn(dataset_fn), self._devices, | ||||
|           self._prefetch_on_device) | ||||
|           self._call_dataset_fn(dataset_fn), | ||||
|           self._devices, | ||||
|           self._prefetch_on_device, | ||||
|           source_device=device_util.resolve("/device:CPU:0")) | ||||
| 
 | ||||
|   # TODO(priyag): Deal with OutOfRange errors once b/111349762 is fixed. | ||||
|   def _run_steps_on_dataset(self, fn, iterator, iterations, | ||||
|  | ||||
| @ -300,9 +300,15 @@ class MirroredStrategyVariableCreationTest(test.TestCase): | ||||
| 
 | ||||
|     dist = mirrored_strategy.MirroredStrategy( | ||||
|         ["/device:GPU:0", "/device:CPU:0"]) | ||||
|     features = dist.distribute_dataset( | ||||
|         lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10) | ||||
|     ).make_one_shot_iterator().get_next() | ||||
|     ds = dist.distribute_dataset( | ||||
|         lambda: dataset_ops.Dataset.from_tensors([[1.]]).repeat(10)) | ||||
|     if context.executing_eagerly(): | ||||
|       iterator = ds.make_one_shot_iterator() | ||||
|     else: | ||||
|       iterator = ds.make_initializable_iterator() | ||||
|       self.evaluate([iterator.initializer]) | ||||
| 
 | ||||
|     features = iterator.get_next() | ||||
| 
 | ||||
|     with dist.scope(): | ||||
|       result = dist.call_for_each_tower( | ||||
|  | ||||
| @ -51,6 +51,7 @@ class Monitor(object): | ||||
|     else: | ||||
|       if session is None: | ||||
|         raise ValueError("Should provide a `session` in Graph mode.") | ||||
|       session.run(step_callable._iterator.initializer)  # pylint: disable=protected-access | ||||
|       self._run_step = session.make_callable(step_callable()) | ||||
|       session.run(variables.global_variables_initializer()) | ||||
| 
 | ||||
|  | ||||
| @ -42,8 +42,11 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): | ||||
|       model_fn, dataset_fn, layer = minimize_loss_example( | ||||
|           optimizer_fn, use_bias=True, use_callable_loss=use_callable_loss) | ||||
| 
 | ||||
|       iterator = distribution.distribute_dataset( | ||||
|           dataset_fn).make_one_shot_iterator() | ||||
|       ds = distribution.distribute_dataset(dataset_fn) | ||||
|       if context.executing_eagerly(): | ||||
|         iterator = ds.make_one_shot_iterator() | ||||
|       else: | ||||
|         iterator = ds.make_initializable_iterator() | ||||
| 
 | ||||
|       def run_step(): | ||||
|         return control_flow_ops.group(distribution.unwrap( | ||||
| @ -52,6 +55,7 @@ class MinimizeLossOptimizerV2Test(test.TestCase, parameterized.TestCase): | ||||
| 
 | ||||
|       if not context.executing_eagerly(): | ||||
|         with self.cached_session() as sess: | ||||
|           sess.run(iterator.initializer) | ||||
|           run_step = sess.make_callable(run_step()) | ||||
|         self.evaluate(variables.global_variables_initializer()) | ||||
| 
 | ||||
|  | ||||
| @ -1,229 +0,0 @@ | ||||
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Extension of prefetching_ops to support more than one device.""" | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| import warnings | ||||
| 
 | ||||
| from tensorflow.contrib.data.python.ops import contrib_op_loader  # pylint: disable=unused-import | ||||
| from tensorflow.contrib.data.python.ops import gen_dataset_ops | ||||
| from tensorflow.contrib.data.python.ops import prefetching_ops | ||||
| from tensorflow.python.data.ops import dataset_ops | ||||
| from tensorflow.python.data.ops import iterator_ops | ||||
| from tensorflow.python.data.util import nest as data_nest | ||||
| from tensorflow.python.data.util import sparse | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.framework import dtypes | ||||
| from tensorflow.python.framework import function | ||||
| from tensorflow.python.framework import ops | ||||
| from tensorflow.python.util import nest | ||||
| 
 | ||||
| 
 | ||||
| # pylint: disable=protected-access | ||||
| class _PrefetchToDeviceIterator(object): | ||||
|   """A replacement for `tf.data.Iterator` that prefetches to another device. | ||||
| 
 | ||||
|   Args: | ||||
|     input_dataset: The input dataset. | ||||
|     one_shot: If true, we make a one shot iterator that's already initialized. | ||||
|     devices: Devices on which to prefetch. | ||||
|     buffer_size: Size of the prefetching buffer. | ||||
|     shared_name: (Optional.) If non-empty, the returned iterator will be | ||||
|         shared under the given name across multiple sessions that share the | ||||
|         same devices (e.g. when using a remote server). Only used if one_shot | ||||
|         is False. | ||||
| 
 | ||||
|   Returns: | ||||
|     An Iterator type object. | ||||
|   """ | ||||
| 
 | ||||
|   def __init__(self, | ||||
|                input_dataset, | ||||
|                one_shot, | ||||
|                devices, | ||||
|                buffer_size, | ||||
|                shared_name=None): | ||||
|     self._input_dataset = input_dataset | ||||
|     self._get_next_call_count = 0 | ||||
|     self._one_shot = one_shot | ||||
|     if shared_name is None: | ||||
|       shared_name = "" | ||||
|     self._devices = devices | ||||
| 
 | ||||
|     if self._one_shot: | ||||
|       self._input_iterator = input_dataset.make_one_shot_iterator() | ||||
|     else: | ||||
|       self._input_iterator = iterator_ops.Iterator.from_structure( | ||||
|           self._input_dataset.output_types, self._input_dataset.output_shapes, | ||||
|           shared_name, self._input_dataset.output_classes) | ||||
|     input_iterator_handle = self._input_iterator.string_handle() | ||||
| 
 | ||||
|     @function.Defun(dtypes.string) | ||||
|     def _prefetch_fn(handle): | ||||
|       """Prefetches one element from `input_iterator`.""" | ||||
|       remote_iterator = iterator_ops.Iterator.from_string_handle( | ||||
|           handle, self._input_iterator.output_types, | ||||
|           self._input_iterator.output_shapes, | ||||
|           self._input_iterator.output_classes) | ||||
|       ret = remote_iterator.get_next() | ||||
|       return nest.flatten(sparse.serialize_sparse_tensors(ret)) | ||||
| 
 | ||||
|     target_device = gen_dataset_ops.iterator_get_device( | ||||
|         self._input_iterator._iterator_resource) | ||||
|     self._buffering_resources = [] | ||||
|     for device in nest.flatten(self._devices): | ||||
|       with ops.device(device): | ||||
|         buffer_resource_handle = prefetching_ops.function_buffering_resource( | ||||
|             f=_prefetch_fn, | ||||
|             output_types=data_nest.flatten( | ||||
|                 sparse.as_dense_types(self._input_dataset.output_types, | ||||
|                                       self._input_dataset.output_classes)), | ||||
|             target_device=target_device, | ||||
|             string_arg=input_iterator_handle, | ||||
|             buffer_size=buffer_size, | ||||
|             shared_name=shared_name) | ||||
|         self._buffering_resources.append(buffer_resource_handle) | ||||
| 
 | ||||
|     if not self._one_shot: | ||||
|       reset_ops = [] | ||||
|       for buffer_resource in self._buffering_resources: | ||||
|         reset_ops.append( | ||||
|             prefetching_ops.function_buffering_resource_reset(buffer_resource)) | ||||
|       with ops.control_dependencies(reset_ops): | ||||
|         self._initializer = self._input_iterator.make_initializer( | ||||
|             self._input_dataset) | ||||
| 
 | ||||
|   def get_next(self, name=None): | ||||
|     """See `tf.data.Iterator.get_next`.""" | ||||
|     self._get_next_call_count += 1 | ||||
|     if self._get_next_call_count > iterator_ops.GET_NEXT_CALL_WARNING_THRESHOLD: | ||||
|       warnings.warn(iterator_ops.GET_NEXT_CALL_WARNING_MESSAGE) | ||||
| 
 | ||||
|     flat_result = [] | ||||
|     # TODO(priyag): This will fail if the input size (typically number of | ||||
|     # batches) is not divisible by number of devices. | ||||
|     # How do we handle that more gracefully / let the user know? | ||||
|     for buffer_resource in self._buffering_resources: | ||||
|       flat_ret = gen_dataset_ops.function_buffering_resource_get_next( | ||||
|           buffer_resource, | ||||
|           output_types=data_nest.flatten(sparse.as_dense_types( | ||||
|               self.output_types, self.output_classes)), name=name) | ||||
| 
 | ||||
|       ret = sparse.deserialize_sparse_tensors( | ||||
|           data_nest.pack_sequence_as(self.output_types, flat_ret), | ||||
|           self.output_types, self.output_shapes, self.output_classes) | ||||
| 
 | ||||
|       for tensor, shape in zip( | ||||
|           data_nest.flatten(ret), data_nest.flatten(self.output_shapes)): | ||||
|         if isinstance(tensor, ops.Tensor): | ||||
|           tensor.set_shape(shape) | ||||
|       flat_result.append(ret) | ||||
| 
 | ||||
|     return nest.pack_sequence_as(self._devices, flat_result) | ||||
| 
 | ||||
|   @property | ||||
|   def initializer(self): | ||||
|     if self._one_shot: | ||||
|       raise NotImplementedError("Can't initialize a one_shot_iterator") | ||||
|     return self._initializer | ||||
| 
 | ||||
|   @property | ||||
|   def output_classes(self): | ||||
|     return self._input_dataset.output_classes | ||||
| 
 | ||||
|   @property | ||||
|   def output_shapes(self): | ||||
|     return self._input_dataset.output_shapes | ||||
| 
 | ||||
|   @property | ||||
|   def output_types(self): | ||||
|     return self._input_dataset.output_types | ||||
| # pylint: enable=protected-access | ||||
| 
 | ||||
| 
 | ||||
| class _PrefetchToDeviceDataset(dataset_ops.UnaryDataset): | ||||
|   """A `Dataset` whose iterator prefetches elements to other device(s).""" | ||||
| 
 | ||||
|   def __init__(self, input_dataset, devices, buffer_size): | ||||
|     super(_PrefetchToDeviceDataset, self).__init__(input_dataset) | ||||
|     self._input_dataset = input_dataset | ||||
|     self._devices = devices | ||||
|     self._buffer_size = buffer_size if buffer_size is not None else 1 | ||||
| 
 | ||||
|   def make_one_shot_iterator(self): | ||||
|     return _PrefetchToDeviceIterator( | ||||
|         self._input_dataset, | ||||
|         one_shot=True, | ||||
|         devices=self._devices, | ||||
|         buffer_size=self._buffer_size) | ||||
| 
 | ||||
|   def make_initializable_iterator(self, shared_name=None): | ||||
|     if context.executing_eagerly(): | ||||
|       raise RuntimeError( | ||||
|           "make_initializable_iterator is not supported when eager " | ||||
|           "execution is enabled.") | ||||
| 
 | ||||
|     return _PrefetchToDeviceIterator( | ||||
|         self._input_dataset, | ||||
|         one_shot=False, | ||||
|         devices=self._devices, | ||||
|         buffer_size=self._buffer_size, | ||||
|         shared_name=shared_name) | ||||
| 
 | ||||
|   def _as_variant_tensor(self): | ||||
|     # TODO(mrry): Raise this error earlier (e.g. when one of the Dataset | ||||
|     # transformation methods is called. | ||||
|     # TODO(mrry): Investigate support for chaining further transformations after | ||||
|     # the prefetch, including GPU support. | ||||
|     raise NotImplementedError("`prefetch_to_devices()` must be the last " | ||||
|                               "transformation in a dataset pipeline.") | ||||
| 
 | ||||
|   # TODO(priyag): Fix the output types, shapes and classes to match the result | ||||
|   # of get_next (which has the additional nesting layer of devices now). | ||||
|   @property | ||||
|   def output_types(self): | ||||
|     return self._input_dataset.output_types | ||||
| 
 | ||||
|   @property | ||||
|   def output_shapes(self): | ||||
|     return self._input_dataset.output_shapes | ||||
| 
 | ||||
|   @property | ||||
|   def output_classes(self): | ||||
|     return self._input_dataset.output_classes | ||||
| 
 | ||||
| 
 | ||||
| def prefetch_to_devices(devices, buffer_size=None): | ||||
|   """A transformation that prefetches dataset values to the given `devices`. | ||||
| 
 | ||||
|   NOTE: Although the transformation creates a `tf.data.Dataset`, the | ||||
|   transformation must be the final `Dataset` in the input pipeline. | ||||
| 
 | ||||
|   Args: | ||||
|     devices: A nested structure of devices on which to prefetch the data. It can | ||||
|       be a single device name, or a tuple or list of device names. | ||||
|     buffer_size: (Optional.) The number of elements to buffer on each device. | ||||
|       Defaults to an automatically chosen value. | ||||
| 
 | ||||
|   Returns: | ||||
|     A `Dataset` transformation function, which can be passed to | ||||
|     `tf.data.Dataset.apply`. | ||||
|   """ | ||||
|   def _apply_fn(dataset): | ||||
|     return _PrefetchToDeviceDataset(dataset, devices, buffer_size) | ||||
| 
 | ||||
|   return _apply_fn | ||||
| @ -1,90 +0,0 @@ | ||||
| # Copyright 2017 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Tests for prefetching_ops_v2.""" | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| from tensorflow.contrib.distribute.python import prefetching_ops_v2 | ||||
| from tensorflow.python.data.ops import dataset_ops | ||||
| from tensorflow.python.framework import errors | ||||
| from tensorflow.python.framework import test_util | ||||
| from tensorflow.python.platform import test | ||||
| 
 | ||||
| 
 | ||||
| class PrefetchingOpsV2Test(test.TestCase): | ||||
| 
 | ||||
|   def testPrefetchToOneDevice(self): | ||||
|     if not test_util.is_gpu_available(): | ||||
|       self.skipTest("No GPU available") | ||||
| 
 | ||||
|     host_dataset = dataset_ops.Dataset.range(10) | ||||
|     device_dataset = host_dataset.apply( | ||||
|         prefetching_ops_v2.prefetch_to_devices("/gpu:0")) | ||||
| 
 | ||||
|     iterator = device_dataset.make_one_shot_iterator() | ||||
|     next_element = iterator.get_next() | ||||
| 
 | ||||
|     with self.cached_session() as sess: | ||||
|       for i in range(10): | ||||
|         self.assertEqual(i, sess.run(next_element)) | ||||
|       with self.assertRaises(errors.OutOfRangeError): | ||||
|         sess.run(next_element) | ||||
| 
 | ||||
|   def testPrefetchToTwoDevicesInAList(self): | ||||
|     if not test_util.is_gpu_available(): | ||||
|       self.skipTest("No GPU available") | ||||
| 
 | ||||
|     host_dataset = dataset_ops.Dataset.range(10) | ||||
|     device_dataset = host_dataset.apply( | ||||
|         prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) | ||||
| 
 | ||||
|     iterator = device_dataset.make_one_shot_iterator() | ||||
|     next_element = iterator.get_next() | ||||
| 
 | ||||
|     output = [] | ||||
|     # TODO(rohanj): Modify test to go till the end of the dataset when we | ||||
|     # switch to MultiDeviceIterator. | ||||
|     with self.cached_session() as sess: | ||||
|       for _ in range(4): | ||||
|         result = sess.run(next_element) | ||||
|         self.assertEqual(2, len(result)) | ||||
|         output.extend(result) | ||||
|       self.assertEquals(set(range(8)), set(output)) | ||||
| 
 | ||||
|   def testPrefetchToTwoDevicesWithReinit(self): | ||||
|     if not test_util.is_gpu_available(): | ||||
|       self.skipTest("No GPU available") | ||||
| 
 | ||||
|     host_dataset = dataset_ops.Dataset.range(10) | ||||
|     device_dataset = host_dataset.apply( | ||||
|         prefetching_ops_v2.prefetch_to_devices(["/cpu:0", "/gpu:0"])) | ||||
| 
 | ||||
|     iterator = device_dataset.make_initializable_iterator() | ||||
|     next_element = iterator.get_next() | ||||
| 
 | ||||
|     # TODO(rohanj): Modify test to go till the end of the dataset when we | ||||
|     # switch to MultiDeviceIterator. | ||||
|     with self.cached_session() as sess: | ||||
|       sess.run(iterator.initializer) | ||||
|       for _ in range(4): | ||||
|         sess.run(next_element) | ||||
|       sess.run(iterator.initializer) | ||||
|       for _ in range(4): | ||||
|         sess.run(next_element) | ||||
| 
 | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|   test.main() | ||||
| @ -19,6 +19,7 @@ from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| from tensorflow.python.eager import backprop | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.training import optimizer as optimizer_lib | ||||
| 
 | ||||
| 
 | ||||
| @ -50,7 +51,11 @@ class StandardInputStep(Step): | ||||
|   def __init__(self, dataset_fn, distribution): | ||||
|     super(StandardInputStep, self).__init__(distribution) | ||||
|     self._distributed_input = distribution.distribute_dataset(dataset_fn) | ||||
|     self._iterator = self._distributed_input.make_one_shot_iterator() | ||||
|     if context.executing_eagerly(): | ||||
|       self._iterator = self._distributed_input.make_one_shot_iterator() | ||||
|     else: | ||||
|       # TODO(priyag): Expose initializer via some initializer property. | ||||
|       self._iterator = self._distributed_input.make_initializable_iterator() | ||||
| 
 | ||||
| 
 | ||||
| class StandardSingleLossStep(StandardInputStep): | ||||
|  | ||||
| @ -50,6 +50,7 @@ class SingleLossStepTest(test.TestCase, parameterized.TestCase): | ||||
|         run_step = single_loss_step | ||||
|       else: | ||||
|         with self.cached_session() as sess: | ||||
|           sess.run(single_loss_step._iterator.initializer) | ||||
|           run_step = sess.make_callable(single_loss_step()) | ||||
|       self.evaluate(variables.global_variables_initializer()) | ||||
| 
 | ||||
|  | ||||
| @ -26,7 +26,7 @@ import weakref | ||||
| import six | ||||
| 
 | ||||
| from tensorflow.contrib.distribute.python import input_ops | ||||
| from tensorflow.contrib.distribute.python import prefetching_ops_v2 | ||||
| from tensorflow.python.data.ops import multi_device_iterator_ops | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.framework import device as tf_device | ||||
| from tensorflow.python.framework import ops | ||||
| @ -683,7 +683,7 @@ class PerDeviceDataIterator(object): | ||||
|   def get_next(self, name=None): | ||||
|     """Scatter the input across devices.""" | ||||
|     if self._prefetch_on_device: | ||||
|       data_list = self._iterator.get_next(name=name) | ||||
|       data_list = self._iterator.get_next() | ||||
|       index = dict(zip(self._devices, data_list)) | ||||
|     else: | ||||
|       batch = self._iterator.get_next(name=name) | ||||
| @ -703,21 +703,26 @@ class PerDeviceDataIterator(object): | ||||
| class PerDeviceDataset(object): | ||||
|   """Like `tf.data.Dataset` split devices, producing `PerDevice` data.""" | ||||
| 
 | ||||
|   def __init__(self, dataset, devices, prefetch_on_device=None): | ||||
|   def __init__( | ||||
|       self, | ||||
|       dataset, | ||||
|       devices, | ||||
|       prefetch_on_device=None, | ||||
|       source_device="/cpu:0", | ||||
|   ): | ||||
|     self._devices = devices | ||||
|     self._source_device = source_device if source_device is not None else "/cpu:0" | ||||
| 
 | ||||
|     # Default to using prefetching in graph mode, unless specified. | ||||
|     # TODO(priyag): Enable prefetching in eager mode. | ||||
|     # TODO(rohanj): Enable prefetching in eager mode. | ||||
|     self._prefetch_on_device = prefetch_on_device | ||||
|     if self._prefetch_on_device is None: | ||||
|       self._prefetch_on_device = not context.executing_eagerly() | ||||
|     assert not (self._prefetch_on_device and context.executing_eagerly()), ( | ||||
|         "Prefetching is only supported in graph mode currently") | ||||
| 
 | ||||
|     if self._prefetch_on_device: | ||||
|       self._dataset = dataset.apply( | ||||
|           prefetching_ops_v2.prefetch_to_devices(self._devices)) | ||||
|     else: | ||||
|     self._dataset = dataset | ||||
|     if not self._prefetch_on_device: | ||||
|       # TODO(priyag): If dropping remainder is not appropriate, find another | ||||
|       # approach to distributing the dataset when not possible to divide evenly. | ||||
|       # Possibly not an issue when we start using PartitionedDataset. | ||||
| @ -725,15 +730,33 @@ class PerDeviceDataset(object): | ||||
| 
 | ||||
|   def make_one_shot_iterator(self): | ||||
|     """Get a one time use iterator for the distributed PerDeviceDataset.""" | ||||
|     # Graph mode prefetching with one shot iterator is disabled. | ||||
|     if not context.executing_eagerly(): | ||||
|       raise ValueError("Cannot create a one shot iterator. Please use " | ||||
|                        "`make_initializable_iterator()` instead.") | ||||
|     # Eager mode prefetching would error out in constructor. Only remaining | ||||
|     # cases are non-prefetching eager / graph mode. We delegate to | ||||
|     # PerDeviceDataIterator to handle them. | ||||
|     dataset_iterator = self._dataset.make_one_shot_iterator() | ||||
|     return PerDeviceDataIterator( | ||||
|         dataset_iterator, self._devices, self._prefetch_on_device) | ||||
|         dataset_iterator, self._devices, prefetch_on_device=False) | ||||
| 
 | ||||
|   def make_initializable_iterator(self): | ||||
|     """Get an initializable iterator for the distributed PerDeviceDataset.""" | ||||
|     dataset_iterator = self._dataset.make_initializable_iterator() | ||||
|     # Eager mode generates already initialized iterators. Hence we cannot create | ||||
|     # an initializable iterator. | ||||
|     if context.executing_eagerly(): | ||||
|       raise ValueError("Cannot create initializable iterator in Eager mode. " | ||||
|                        "Please use `make_one_shot_iterator` instead.") | ||||
|     if self._prefetch_on_device: | ||||
|       dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( | ||||
|           self._dataset, self._devices, source_device=self._source_device) | ||||
|     else: | ||||
|       dataset_iterator = self._dataset.make_initializable_iterator() | ||||
|     return PerDeviceDataIterator( | ||||
|         dataset_iterator, self._devices, self._prefetch_on_device) | ||||
|         dataset_iterator, | ||||
|         self._devices, | ||||
|         prefetch_on_device=self._prefetch_on_device) | ||||
| 
 | ||||
| 
 | ||||
| class MultiWorkerDataIterator(object): | ||||
| @ -813,7 +836,10 @@ class MultiWorkerDataset(object): | ||||
|         worker_input = input_ops.auto_shard_dataset( | ||||
|             worker_input, len(worker_device_map), i) | ||||
|         self._datasets[worker] = PerDeviceDataset( | ||||
|             worker_input, worker_devices, prefetch_on_device=prefetch_on_device) | ||||
|             worker_input, | ||||
|             worker_devices, | ||||
|             source_device=worker, | ||||
|             prefetch_on_device=prefetch_on_device) | ||||
| 
 | ||||
|   def make_one_shot_iterator(self): | ||||
|     iterators = {} | ||||
|  | ||||
| @ -349,7 +349,11 @@ class PerDeviceDatasetTest(test.TestCase): | ||||
|   def _test_iterator_no_prefetch(self, devices, dataset, expected_values): | ||||
|     per_device_dataset = values.PerDeviceDataset( | ||||
|         dataset, devices, prefetch_on_device=False) | ||||
|     iterator = per_device_dataset.make_one_shot_iterator() | ||||
|     if context.executing_eagerly(): | ||||
|       iterator = per_device_dataset.make_one_shot_iterator() | ||||
|     else: | ||||
|       iterator = per_device_dataset.make_initializable_iterator() | ||||
|       self.evaluate([iterator.initializer]) | ||||
| 
 | ||||
|     for expected_value in expected_values: | ||||
|       next_element = iterator.get_next() | ||||
| @ -366,20 +370,14 @@ class PerDeviceDatasetTest(test.TestCase): | ||||
|     if not context.executing_eagerly(): | ||||
|       per_device_dataset = values.PerDeviceDataset( | ||||
|           dataset, devices, prefetch_on_device=True) | ||||
|       iterator = per_device_dataset.make_one_shot_iterator() | ||||
|       iterator = per_device_dataset.make_initializable_iterator() | ||||
|       self.evaluate([iterator.initializer]) | ||||
| 
 | ||||
|       # With prefetching, we cannot guarantee which input ends up on which | ||||
|       # device, so we verify that the complete set seen on all devices is | ||||
|       # correct, and equal numbers are distributed to each device. | ||||
|       combined_actual = [] | ||||
|       combined_expected = [] | ||||
|       for expected_value in expected_values: | ||||
|         next_element = iterator.get_next() | ||||
|         combined_actual.extend(self.evaluate([ | ||||
|             values.select_device(d, next_element) for d in devices])) | ||||
|         combined_expected.extend(expected_value) | ||||
| 
 | ||||
|       self.assertEqual(set(combined_expected), set(combined_actual)) | ||||
|         computed_value = self.evaluate( | ||||
|             [values.select_device(d, next_element) for d in devices]) | ||||
|         self.assertEqual(expected_value, computed_value) | ||||
| 
 | ||||
|       with self.assertRaises(errors.OutOfRangeError): | ||||
|         next_element = iterator.get_next() | ||||
|  | ||||
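As a reading aid, here is a minimal graph-mode sketch of the prefetching path exercised in the test above. It is illustrative only, not part of the diff: it assumes TensorFlow 1.x with the contrib distribute package available and at least one visible GPU, and module paths and names follow this era of the codebase.

import tensorflow as tf  # TF 1.x, graph mode

from tensorflow.contrib.distribute.python import values
from tensorflow.python.data.ops import dataset_ops

devices = ["/device:CPU:0", "/device:GPU:0"]  # assumes a GPU is present
dataset = dataset_ops.Dataset.range(8).batch(2)

# With prefetch_on_device=True, make_initializable_iterator() now builds a
# MultiDeviceIterator under the hood; make_one_shot_iterator() raises in
# graph mode after this change.
per_device_dataset = values.PerDeviceDataset(
    dataset, devices, prefetch_on_device=True)
iterator = per_device_dataset.make_initializable_iterator()
next_element = iterator.get_next()

with tf.Session() as sess:
  sess.run(iterator.initializer)
  # Each step yields one element per device, keyed by device name.
  print(sess.run([values.select_device(d, next_element) for d in devices]))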