From eb32bf1ae46571375b00c9e3e146dad7706286ce Mon Sep 17 00:00:00 2001 From: Bruce Fontaine Date: Mon, 22 Jun 2020 09:43:37 -0700 Subject: [PATCH] Clean up _set_prefetch_on_host method from TPUStrategy. PiperOrigin-RevId: 317672388 Change-Id: Icb03401a792df4d071659f101ba06edf59010577 --- .../custom_training_loop_input_test.py | 10 +-- tensorflow/python/distribute/tpu_strategy.py | 43 +++++++----- .../python/distribute/tpu_strategy_test.py | 69 +++++++++++++++++-- 3 files changed, 98 insertions(+), 24 deletions(-) diff --git a/tensorflow/python/distribute/custom_training_loop_input_test.py b/tensorflow/python/distribute/custom_training_loop_input_test.py index 5660b5839ce..748cb7834fc 100644 --- a/tensorflow/python/distribute/custom_training_loop_input_test.py +++ b/tensorflow/python/distribute/custom_training_loop_input_test.py @@ -454,10 +454,11 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, )) def testDistributeDatasetHostPrefetch(self, distribution): data = [5., 6., 7., 8.] - distribution.extended._set_prefetch_on_host(True) # pylint: disable=protected-access input_iterator = iter( distribution.experimental_distribute_dataset( - get_dataset_from_tensor_slices(data).batch(2))) + get_dataset_from_tensor_slices(data).batch(2), + distribute_lib.InputOptions( + experimental_prefetch_to_device=False))) local_results = distribution.experimental_local_results( input_iterator.get_next()) @@ -473,10 +474,11 @@ class InputIterationTest(test.TestCase, parameterized.TestCase, )) def testDistributeDatasetFunctionHostPrefetch(self, distribution): data = [5., 6., 7., 8.] 
- distribution.extended._set_prefetch_on_host(True) # pylint: disable=protected-access input_iterator = iter( distribution.experimental_distribute_datasets_from_function( - lambda _: get_dataset_from_tensor_slices(data))) + lambda _: get_dataset_from_tensor_slices(data), + distribute_lib.InputOptions( + experimental_prefetch_to_device=False))) local_results = distribution.experimental_local_results( input_iterator.get_next()) diff --git a/tensorflow/python/distribute/tpu_strategy.py b/tensorflow/python/distribute/tpu_strategy.py index dcd1671841f..4b3c4be0ccd 100644 --- a/tensorflow/python/distribute/tpu_strategy.py +++ b/tensorflow/python/distribute/tpu_strategy.py @@ -47,12 +47,14 @@ from tensorflow.python.framework import constant_op from tensorflow.python.framework import device_spec from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import sparse_tensor from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import from tensorflow.python.tpu import tpu from tensorflow.python.tpu import tpu_strategy_util @@ -515,7 +517,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): self._require_static_shapes = True self.experimental_enable_get_next_as_optional = True - self._prefetch_to_device = True self._logical_device_stack = [0] @@ -527,16 +528,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): context.async_wait() atexit.register(async_wait) - # TODO(bfontain): Remove once a proper dataset API exists for prefetching - # a dataset to multiple devices exists. 
- # If value is true, this forces prefetch of data to the host's memeory rather - # than the individual TPU device's memory. This is needed when using for TPU - # Embeddings as a) sparse tensors cannot be prefetched to the TPU device - # memory and b) TPU Embedding enqueue operation are CPU ops and this avoids - # a copy back to the host for dense tensors - def _set_prefetch_on_host(self, value): - self._prefetch_to_device = not value - def _validate_colocate_with_variable(self, colocate_with_variable): distribute_utils. validate_colocate(colocate_with_variable, self) @@ -575,17 +566,32 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): session) def _get_input_workers(self, options): - prefetch_to_device = self._prefetch_to_device - if options: - prefetch_to_device = options.experimental_prefetch_to_device - if prefetch_to_device: + if not options or options.experimental_prefetch_to_device: return input_lib.InputWorkers( tuple(self._device_input_worker_devices.items())) else: return input_lib.InputWorkers( tuple(self._host_input_worker_devices.items())) + def _check_spec(self, element_spec): + if isinstance(element_spec, values.PerReplicaSpec): + element_spec = element_spec._component_specs # pylint: disable=protected-access + specs = nest.flatten_with_joined_string_paths(element_spec) + for path, spec in specs: + if isinstance(spec, (sparse_tensor.SparseTensorSpec, + ragged_tensor.RaggedTensorSpec)): + raise ValueError( + "Found tensor {} with spec {}. TPUStrategy does not support " + "distributed datasets with device prefetch when using sparse or " + "ragged tensors. 
If you intend to use sparse or ragged tensors, " + "please pass a tf.distribute.InputOptions object with " + "experimental_prefetch_to_device set to False to your dataset " + "distribution function.".format(path, type(spec))) + def _experimental_distribute_dataset(self, dataset, options): + if options is None or options.experimental_prefetch_to_device: + self._check_spec(dataset.element_spec) + return input_lib.get_distributed_dataset( dataset, self._get_input_workers(options), @@ -603,12 +609,17 @@ class TPUExtended(distribute_lib.StrategyExtendedV1): input_pipeline_id=i, num_replicas_in_sync=self._num_replicas_in_sync)) - return input_lib.get_distributed_datasets_from_function( + distributed_dataset = input_lib.get_distributed_datasets_from_function( dataset_fn, input_workers, input_contexts, self._container_strategy()) + # We can only check after the dataset_fn is called. + if options is None or options.experimental_prefetch_to_device: + self._check_spec(distributed_dataset.element_spec) + return distributed_dataset + def _experimental_distribute_values_from_function(self, value_fn): per_replica_values = [] for replica_id in range(self._num_replicas_in_sync): diff --git a/tensorflow/python/distribute/tpu_strategy_test.py b/tensorflow/python/distribute/tpu_strategy_test.py index 4070336aae8..5e47e750d87 100644 --- a/tensorflow/python/distribute/tpu_strategy_test.py +++ b/tensorflow/python/distribute/tpu_strategy_test.py @@ -44,6 +44,7 @@ from tensorflow.python.ops import embedding_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables +from tensorflow.python.ops.ragged import ragged_tensor from tensorflow.python.platform import flags from tensorflow.python.platform import tf_logging as logging from tensorflow.python.tpu import device_assignment as device_assignment_lib @@ -475,9 +476,11 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): return dataset.map(make_sparse) - 
strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access dataset = iter( - strategy.experimental_distribute_datasets_from_function(dataset_fn)) + strategy.experimental_distribute_datasets_from_function( + dataset_fn, + distribute_lib.InputOptions( + experimental_prefetch_to_device=False))) result = sparse_lookup(dataset) self.assertAllEqual(result, @@ -520,9 +523,11 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase): return dataset.map(make_sparse) - strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access dataset = iter( - strategy.experimental_distribute_datasets_from_function(dataset_fn)) + strategy.experimental_distribute_datasets_from_function( + dataset_fn, + options=distribute_lib.InputOptions( + experimental_prefetch_to_device=False))) result = sparse_lookup(dataset) self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]]) @@ -594,5 +599,61 @@ class TPUStrategyDataPrefetchTest(test.TestCase): dataset_item.values[0].device) self.assertEqual(dataset_location.device_type, "CPU") + def test_prefetch_to_device_sparse_dataset(self): + strategy = get_tpu_strategy() + # Values here aren't important. + dataset = dataset_ops.Dataset.from_tensors( + sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]], + values=[1, 2, 3], + dense_shape=[2, 2])) + dataset = dataset.repeat() + dataset = dataset.batch(strategy.num_replicas_in_sync) + + with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"): + iter(strategy.experimental_distribute_dataset(dataset)) + + def test_prefetch_to_device_ragged_dataset(self): + strategy = get_tpu_strategy() + # Values here aren't important. 
+ dataset = dataset_ops.Dataset.from_tensors( + ragged_tensor.RaggedTensor.from_row_splits( + values=[1, 2, 3], + row_splits=[0, 2, 3])) + dataset = dataset.repeat() + dataset = dataset.batch(strategy.num_replicas_in_sync) + + with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"): + iter(strategy.experimental_distribute_dataset(dataset)) + + def test_prefetch_to_device_sparse_dataset_fn(self): + strategy = get_tpu_strategy() + def dataset_fn(ctx): + del ctx + # Values here aren't important. + dataset = dataset_ops.Dataset.from_tensors( + sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]], + values=[1, 2, 3], + dense_shape=[2, 2])) + dataset = dataset.repeat() + return dataset.batch(strategy.num_replicas_in_sync) + + with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"): + iter(strategy.experimental_distribute_datasets_from_function(dataset_fn)) + + def test_prefetch_to_device_ragged_dataset_fn(self): + strategy = get_tpu_strategy() + def dataset_fn(ctx): + del ctx + # Values here aren't important. + dataset = dataset_ops.Dataset.from_tensors( + ragged_tensor.RaggedTensor.from_row_splits( + values=[1, 2, 3], + row_splits=[0, 2, 3])) + dataset = dataset.repeat() + return dataset.batch(strategy.num_replicas_in_sync) + + with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"): + iter(strategy.experimental_distribute_datasets_from_function(dataset_fn)) + if __name__ == "__main__": test.main()