Clean up _set_prefetch_on_host method from TPUStrategy.
PiperOrigin-RevId: 317672388 Change-Id: Icb03401a792df4d071659f101ba06edf59010577
This commit is contained in:
parent
22f6939be9
commit
eb32bf1ae4
|
@ -454,10 +454,11 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
|
|||
))
|
||||
def testDistributeDatasetHostPrefetch(self, distribution):
|
||||
data = [5., 6., 7., 8.]
|
||||
distribution.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
|
||||
input_iterator = iter(
|
||||
distribution.experimental_distribute_dataset(
|
||||
get_dataset_from_tensor_slices(data).batch(2)))
|
||||
get_dataset_from_tensor_slices(data).batch(2),
|
||||
distribute_lib.InputOptions(
|
||||
experimental_prefetch_to_device=False)))
|
||||
|
||||
local_results = distribution.experimental_local_results(
|
||||
input_iterator.get_next())
|
||||
|
@ -473,10 +474,11 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
|
|||
))
|
||||
def testDistributeDatasetFunctionHostPrefetch(self, distribution):
|
||||
data = [5., 6., 7., 8.]
|
||||
distribution.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
|
||||
input_iterator = iter(
|
||||
distribution.experimental_distribute_datasets_from_function(
|
||||
lambda _: get_dataset_from_tensor_slices(data)))
|
||||
lambda _: get_dataset_from_tensor_slices(data),
|
||||
distribute_lib.InputOptions(
|
||||
experimental_prefetch_to_device=False)))
|
||||
|
||||
local_results = distribution.experimental_local_results(
|
||||
input_iterator.get_next())
|
||||
|
|
|
@ -47,12 +47,14 @@ from tensorflow.python.framework import constant_op
|
|||
from tensorflow.python.framework import device_spec
|
||||
from tensorflow.python.framework import dtypes
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.framework import sparse_tensor
|
||||
from tensorflow.python.framework import tensor_shape
|
||||
from tensorflow.python.framework import tensor_util
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.ops import control_flow_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import resource_variable_ops
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import
|
||||
from tensorflow.python.tpu import tpu
|
||||
from tensorflow.python.tpu import tpu_strategy_util
|
||||
|
@ -515,7 +517,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
|||
self._require_static_shapes = True
|
||||
|
||||
self.experimental_enable_get_next_as_optional = True
|
||||
self._prefetch_to_device = True
|
||||
|
||||
self._logical_device_stack = [0]
|
||||
|
||||
|
@ -527,16 +528,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
|||
context.async_wait()
|
||||
atexit.register(async_wait)
|
||||
|
||||
# TODO(bfontain): Remove once a proper dataset API exists for prefetching
|
||||
# a dataset to multiple devices exists.
|
||||
# If value is true, this forces prefetch of data to the host's memeory rather
|
||||
# than the individual TPU device's memory. This is needed when using for TPU
|
||||
# Embeddings as a) sparse tensors cannot be prefetched to the TPU device
|
||||
# memory and b) TPU Embedding enqueue operation are CPU ops and this avoids
|
||||
# a copy back to the host for dense tensors
|
||||
def _set_prefetch_on_host(self, value):
|
||||
self._prefetch_to_device = not value
|
||||
|
||||
def _validate_colocate_with_variable(self, colocate_with_variable):
|
||||
distribute_utils. validate_colocate(colocate_with_variable, self)
|
||||
|
||||
|
@ -575,17 +566,32 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
|||
session)
|
||||
|
||||
def _get_input_workers(self, options):
|
||||
prefetch_to_device = self._prefetch_to_device
|
||||
if options:
|
||||
prefetch_to_device = options.experimental_prefetch_to_device
|
||||
if prefetch_to_device:
|
||||
if not options or options.experimental_prefetch_to_device:
|
||||
return input_lib.InputWorkers(
|
||||
tuple(self._device_input_worker_devices.items()))
|
||||
else:
|
||||
return input_lib.InputWorkers(
|
||||
tuple(self._host_input_worker_devices.items()))
|
||||
|
||||
def _check_spec(self, element_spec):
|
||||
if isinstance(element_spec, values.PerReplicaSpec):
|
||||
element_spec = element_spec._component_specs # pylint: disable=protected-access
|
||||
specs = nest.flatten_with_joined_string_paths(element_spec)
|
||||
for path, spec in specs:
|
||||
if isinstance(spec, (sparse_tensor.SparseTensorSpec,
|
||||
ragged_tensor.RaggedTensorSpec)):
|
||||
raise ValueError(
|
||||
"Found tensor {} with spec {}. TPUStrategy does not support "
|
||||
"distributed datasets with device prefetch when using sparse or "
|
||||
"ragged tensors. If you indend to use sparse or ragged tensors, "
|
||||
"please pass a tf.distribute.InputOptions object with "
|
||||
"experimental_prefetch_to_device set to False to your dataset "
|
||||
"distribution function.".format(path, type(spec)))
|
||||
|
||||
def _experimental_distribute_dataset(self, dataset, options):
|
||||
if options is None or options.experimental_prefetch_to_device:
|
||||
self._check_spec(dataset.element_spec)
|
||||
|
||||
return input_lib.get_distributed_dataset(
|
||||
dataset,
|
||||
self._get_input_workers(options),
|
||||
|
@ -603,12 +609,17 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
|||
input_pipeline_id=i,
|
||||
num_replicas_in_sync=self._num_replicas_in_sync))
|
||||
|
||||
return input_lib.get_distributed_datasets_from_function(
|
||||
distributed_dataset = input_lib.get_distributed_datasets_from_function(
|
||||
dataset_fn,
|
||||
input_workers,
|
||||
input_contexts,
|
||||
self._container_strategy())
|
||||
|
||||
# We can only check after the dataset_fn is called.
|
||||
if options is None or options.experimental_prefetch_to_device:
|
||||
self._check_spec(distributed_dataset.element_spec)
|
||||
return distributed_dataset
|
||||
|
||||
def _experimental_distribute_values_from_function(self, value_fn):
|
||||
per_replica_values = []
|
||||
for replica_id in range(self._num_replicas_in_sync):
|
||||
|
|
|
@ -44,6 +44,7 @@ from tensorflow.python.ops import embedding_ops
|
|||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.ops.ragged import ragged_tensor
|
||||
from tensorflow.python.platform import flags
|
||||
from tensorflow.python.platform import tf_logging as logging
|
||||
from tensorflow.python.tpu import device_assignment as device_assignment_lib
|
||||
|
@ -475,9 +476,11 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
|||
|
||||
return dataset.map(make_sparse)
|
||||
|
||||
strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
|
||||
dataset = iter(
|
||||
strategy.experimental_distribute_datasets_from_function(dataset_fn))
|
||||
strategy.experimental_distribute_datasets_from_function(
|
||||
dataset_fn,
|
||||
distribute_lib.InputOptions(
|
||||
experimental_prefetch_to_device=False)))
|
||||
|
||||
result = sparse_lookup(dataset)
|
||||
self.assertAllEqual(result,
|
||||
|
@ -520,9 +523,11 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
|||
|
||||
return dataset.map(make_sparse)
|
||||
|
||||
strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
|
||||
dataset = iter(
|
||||
strategy.experimental_distribute_datasets_from_function(dataset_fn))
|
||||
strategy.experimental_distribute_datasets_from_function(
|
||||
dataset_fn,
|
||||
options=distribute_lib.InputOptions(
|
||||
experimental_prefetch_to_device=False)))
|
||||
|
||||
result = sparse_lookup(dataset)
|
||||
self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])
|
||||
|
@ -594,5 +599,61 @@ class TPUStrategyDataPrefetchTest(test.TestCase):
|
|||
dataset_item.values[0].device)
|
||||
self.assertEqual(dataset_location.device_type, "CPU")
|
||||
|
||||
def test_prefetch_to_device_sparse_dataset(self):
|
||||
strategy = get_tpu_strategy()
|
||||
# Values here aren't important.
|
||||
dataset = dataset_ops.Dataset.from_tensors(
|
||||
sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
|
||||
values=[1, 2, 3],
|
||||
dense_shape=[2, 2]))
|
||||
dataset = dataset.repeat()
|
||||
dataset = dataset.batch(strategy.num_replicas_in_sync)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
|
||||
iter(strategy.experimental_distribute_dataset(dataset))
|
||||
|
||||
def test_prefetch_to_device_ragged_dataset(self):
|
||||
strategy = get_tpu_strategy()
|
||||
# Values here aren't important.
|
||||
dataset = dataset_ops.Dataset.from_tensors(
|
||||
ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=[1, 2, 3],
|
||||
row_splits=[0, 2, 3]))
|
||||
dataset = dataset.repeat()
|
||||
dataset = dataset.batch(strategy.num_replicas_in_sync)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
|
||||
iter(strategy.experimental_distribute_dataset(dataset))
|
||||
|
||||
def test_prefetch_to_device_sparse_dataset_fn(self):
|
||||
strategy = get_tpu_strategy()
|
||||
def dataset_fn(ctx):
|
||||
del ctx
|
||||
# Values here aren't important.
|
||||
dataset = dataset_ops.Dataset.from_tensors(
|
||||
sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
|
||||
values=[1, 2, 3],
|
||||
dense_shape=[2, 2]))
|
||||
dataset = dataset.repeat()
|
||||
return dataset.batch(strategy.num_replicas_in_sync)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
|
||||
iter(strategy.experimental_distribute_datasets_from_function(dataset_fn))
|
||||
|
||||
def test_prefetch_to_device_ragged_dataset_fn(self):
|
||||
strategy = get_tpu_strategy()
|
||||
def dataset_fn(ctx):
|
||||
del ctx
|
||||
# Values here aren't important.
|
||||
dataset = dataset_ops.Dataset.from_tensors(
|
||||
ragged_tensor.RaggedTensor.from_row_splits(
|
||||
values=[1, 2, 3],
|
||||
row_splits=[0, 2, 3]))
|
||||
dataset = dataset.repeat()
|
||||
return dataset.batch(strategy.num_replicas_in_sync)
|
||||
|
||||
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
|
||||
iter(strategy.experimental_distribute_datasets_from_function(dataset_fn))
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
|
Loading…
Reference in New Issue