Clean up _set_prefetch_on_host method from TPUStrategy.

PiperOrigin-RevId: 317672388
Change-Id: Icb03401a792df4d071659f101ba06edf59010577
This commit is contained in:
Bruce Fontaine 2020-06-22 09:43:37 -07:00 committed by TensorFlower Gardener
parent 22f6939be9
commit eb32bf1ae4
3 changed files with 98 additions and 24 deletions

View File

@ -454,10 +454,11 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
)) ))
def testDistributeDatasetHostPrefetch(self, distribution): def testDistributeDatasetHostPrefetch(self, distribution):
data = [5., 6., 7., 8.] data = [5., 6., 7., 8.]
distribution.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
input_iterator = iter( input_iterator = iter(
distribution.experimental_distribute_dataset( distribution.experimental_distribute_dataset(
get_dataset_from_tensor_slices(data).batch(2))) get_dataset_from_tensor_slices(data).batch(2),
distribute_lib.InputOptions(
experimental_prefetch_to_device=False)))
local_results = distribution.experimental_local_results( local_results = distribution.experimental_local_results(
input_iterator.get_next()) input_iterator.get_next())
@ -473,10 +474,11 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
)) ))
def testDistributeDatasetFunctionHostPrefetch(self, distribution): def testDistributeDatasetFunctionHostPrefetch(self, distribution):
data = [5., 6., 7., 8.] data = [5., 6., 7., 8.]
distribution.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
input_iterator = iter( input_iterator = iter(
distribution.experimental_distribute_datasets_from_function( distribution.experimental_distribute_datasets_from_function(
lambda _: get_dataset_from_tensor_slices(data))) lambda _: get_dataset_from_tensor_slices(data),
distribute_lib.InputOptions(
experimental_prefetch_to_device=False)))
local_results = distribution.experimental_local_results( local_results = distribution.experimental_local_results(
input_iterator.get_next()) input_iterator.get_next())

View File

@ -47,12 +47,14 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import device_spec from tensorflow.python.framework import device_spec
from tensorflow.python.framework import dtypes from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import
from tensorflow.python.tpu import tpu from tensorflow.python.tpu import tpu
from tensorflow.python.tpu import tpu_strategy_util from tensorflow.python.tpu import tpu_strategy_util
@ -515,7 +517,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
self._require_static_shapes = True self._require_static_shapes = True
self.experimental_enable_get_next_as_optional = True self.experimental_enable_get_next_as_optional = True
self._prefetch_to_device = True
self._logical_device_stack = [0] self._logical_device_stack = [0]
@ -527,16 +528,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
context.async_wait() context.async_wait()
atexit.register(async_wait) atexit.register(async_wait)
# TODO(bfontain): Remove once a proper dataset API exists for prefetching
# a dataset to multiple devices exists.
# If value is true, this forces prefetch of data to the host's memeory rather
# than the individual TPU device's memory. This is needed when using for TPU
# Embeddings as a) sparse tensors cannot be prefetched to the TPU device
# memory and b) TPU Embedding enqueue operation are CPU ops and this avoids
# a copy back to the host for dense tensors
def _set_prefetch_on_host(self, value):
self._prefetch_to_device = not value
def _validate_colocate_with_variable(self, colocate_with_variable): def _validate_colocate_with_variable(self, colocate_with_variable):
distribute_utils. validate_colocate(colocate_with_variable, self) distribute_utils. validate_colocate(colocate_with_variable, self)
@ -575,17 +566,32 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
session) session)
def _get_input_workers(self, options): def _get_input_workers(self, options):
prefetch_to_device = self._prefetch_to_device if not options or options.experimental_prefetch_to_device:
if options:
prefetch_to_device = options.experimental_prefetch_to_device
if prefetch_to_device:
return input_lib.InputWorkers( return input_lib.InputWorkers(
tuple(self._device_input_worker_devices.items())) tuple(self._device_input_worker_devices.items()))
else: else:
return input_lib.InputWorkers( return input_lib.InputWorkers(
tuple(self._host_input_worker_devices.items())) tuple(self._host_input_worker_devices.items()))
def _check_spec(self, element_spec):
if isinstance(element_spec, values.PerReplicaSpec):
element_spec = element_spec._component_specs # pylint: disable=protected-access
specs = nest.flatten_with_joined_string_paths(element_spec)
for path, spec in specs:
if isinstance(spec, (sparse_tensor.SparseTensorSpec,
ragged_tensor.RaggedTensorSpec)):
raise ValueError(
"Found tensor {} with spec {}. TPUStrategy does not support "
"distributed datasets with device prefetch when using sparse or "
"ragged tensors. If you indend to use sparse or ragged tensors, "
"please pass a tf.distribute.InputOptions object with "
"experimental_prefetch_to_device set to False to your dataset "
"distribution function.".format(path, type(spec)))
def _experimental_distribute_dataset(self, dataset, options): def _experimental_distribute_dataset(self, dataset, options):
if options is None or options.experimental_prefetch_to_device:
self._check_spec(dataset.element_spec)
return input_lib.get_distributed_dataset( return input_lib.get_distributed_dataset(
dataset, dataset,
self._get_input_workers(options), self._get_input_workers(options),
@ -603,12 +609,17 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
input_pipeline_id=i, input_pipeline_id=i,
num_replicas_in_sync=self._num_replicas_in_sync)) num_replicas_in_sync=self._num_replicas_in_sync))
return input_lib.get_distributed_datasets_from_function( distributed_dataset = input_lib.get_distributed_datasets_from_function(
dataset_fn, dataset_fn,
input_workers, input_workers,
input_contexts, input_contexts,
self._container_strategy()) self._container_strategy())
# We can only check after the dataset_fn is called.
if options is None or options.experimental_prefetch_to_device:
self._check_spec(distributed_dataset.element_spec)
return distributed_dataset
def _experimental_distribute_values_from_function(self, value_fn): def _experimental_distribute_values_from_function(self, value_fn):
per_replica_values = [] per_replica_values = []
for replica_id in range(self._num_replicas_in_sync): for replica_id in range(self._num_replicas_in_sync):

View File

@ -44,6 +44,7 @@ from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
from tensorflow.python.ops.ragged import ragged_tensor
from tensorflow.python.platform import flags from tensorflow.python.platform import flags
from tensorflow.python.platform import tf_logging as logging from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.tpu import device_assignment as device_assignment_lib from tensorflow.python.tpu import device_assignment as device_assignment_lib
@ -475,9 +476,11 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
return dataset.map(make_sparse) return dataset.map(make_sparse)
strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
dataset = iter( dataset = iter(
strategy.experimental_distribute_datasets_from_function(dataset_fn)) strategy.experimental_distribute_datasets_from_function(
dataset_fn,
distribute_lib.InputOptions(
experimental_prefetch_to_device=False)))
result = sparse_lookup(dataset) result = sparse_lookup(dataset)
self.assertAllEqual(result, self.assertAllEqual(result,
@ -520,9 +523,11 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
return dataset.map(make_sparse) return dataset.map(make_sparse)
strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
dataset = iter( dataset = iter(
strategy.experimental_distribute_datasets_from_function(dataset_fn)) strategy.experimental_distribute_datasets_from_function(
dataset_fn,
options=distribute_lib.InputOptions(
experimental_prefetch_to_device=False)))
result = sparse_lookup(dataset) result = sparse_lookup(dataset)
self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]]) self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])
@ -594,5 +599,61 @@ class TPUStrategyDataPrefetchTest(test.TestCase):
dataset_item.values[0].device) dataset_item.values[0].device)
self.assertEqual(dataset_location.device_type, "CPU") self.assertEqual(dataset_location.device_type, "CPU")
def test_prefetch_to_device_sparse_dataset(self):
strategy = get_tpu_strategy()
# Values here aren't important.
dataset = dataset_ops.Dataset.from_tensors(
sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
values=[1, 2, 3],
dense_shape=[2, 2]))
dataset = dataset.repeat()
dataset = dataset.batch(strategy.num_replicas_in_sync)
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
iter(strategy.experimental_distribute_dataset(dataset))
def test_prefetch_to_device_ragged_dataset(self):
strategy = get_tpu_strategy()
# Values here aren't important.
dataset = dataset_ops.Dataset.from_tensors(
ragged_tensor.RaggedTensor.from_row_splits(
values=[1, 2, 3],
row_splits=[0, 2, 3]))
dataset = dataset.repeat()
dataset = dataset.batch(strategy.num_replicas_in_sync)
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
iter(strategy.experimental_distribute_dataset(dataset))
def test_prefetch_to_device_sparse_dataset_fn(self):
strategy = get_tpu_strategy()
def dataset_fn(ctx):
del ctx
# Values here aren't important.
dataset = dataset_ops.Dataset.from_tensors(
sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
values=[1, 2, 3],
dense_shape=[2, 2]))
dataset = dataset.repeat()
return dataset.batch(strategy.num_replicas_in_sync)
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
iter(strategy.experimental_distribute_datasets_from_function(dataset_fn))
def test_prefetch_to_device_ragged_dataset_fn(self):
strategy = get_tpu_strategy()
def dataset_fn(ctx):
del ctx
# Values here aren't important.
dataset = dataset_ops.Dataset.from_tensors(
ragged_tensor.RaggedTensor.from_row_splits(
values=[1, 2, 3],
row_splits=[0, 2, 3]))
dataset = dataset.repeat()
return dataset.batch(strategy.num_replicas_in_sync)
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
iter(strategy.experimental_distribute_datasets_from_function(dataset_fn))
if __name__ == "__main__": if __name__ == "__main__":
test.main() test.main()