Clean up _set_prefetch_on_host method from TPUStrategy.
PiperOrigin-RevId: 317672388 Change-Id: Icb03401a792df4d071659f101ba06edf59010577
This commit is contained in:
parent
22f6939be9
commit
eb32bf1ae4
|
@ -454,10 +454,11 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
|
||||||
))
|
))
|
||||||
def testDistributeDatasetHostPrefetch(self, distribution):
|
def testDistributeDatasetHostPrefetch(self, distribution):
|
||||||
data = [5., 6., 7., 8.]
|
data = [5., 6., 7., 8.]
|
||||||
distribution.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
|
|
||||||
input_iterator = iter(
|
input_iterator = iter(
|
||||||
distribution.experimental_distribute_dataset(
|
distribution.experimental_distribute_dataset(
|
||||||
get_dataset_from_tensor_slices(data).batch(2)))
|
get_dataset_from_tensor_slices(data).batch(2),
|
||||||
|
distribute_lib.InputOptions(
|
||||||
|
experimental_prefetch_to_device=False)))
|
||||||
|
|
||||||
local_results = distribution.experimental_local_results(
|
local_results = distribution.experimental_local_results(
|
||||||
input_iterator.get_next())
|
input_iterator.get_next())
|
||||||
|
@ -473,10 +474,11 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
|
||||||
))
|
))
|
||||||
def testDistributeDatasetFunctionHostPrefetch(self, distribution):
|
def testDistributeDatasetFunctionHostPrefetch(self, distribution):
|
||||||
data = [5., 6., 7., 8.]
|
data = [5., 6., 7., 8.]
|
||||||
distribution.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
|
|
||||||
input_iterator = iter(
|
input_iterator = iter(
|
||||||
distribution.experimental_distribute_datasets_from_function(
|
distribution.experimental_distribute_datasets_from_function(
|
||||||
lambda _: get_dataset_from_tensor_slices(data)))
|
lambda _: get_dataset_from_tensor_slices(data),
|
||||||
|
distribute_lib.InputOptions(
|
||||||
|
experimental_prefetch_to_device=False)))
|
||||||
|
|
||||||
local_results = distribution.experimental_local_results(
|
local_results = distribution.experimental_local_results(
|
||||||
input_iterator.get_next())
|
input_iterator.get_next())
|
||||||
|
|
|
@ -47,12 +47,14 @@ from tensorflow.python.framework import constant_op
|
||||||
from tensorflow.python.framework import device_spec
|
from tensorflow.python.framework import device_spec
|
||||||
from tensorflow.python.framework import dtypes
|
from tensorflow.python.framework import dtypes
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
|
from tensorflow.python.framework import sparse_tensor
|
||||||
from tensorflow.python.framework import tensor_shape
|
from tensorflow.python.framework import tensor_shape
|
||||||
from tensorflow.python.framework import tensor_util
|
from tensorflow.python.framework import tensor_util
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.ops import control_flow_ops
|
from tensorflow.python.ops import control_flow_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import resource_variable_ops
|
from tensorflow.python.ops import resource_variable_ops
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import
|
from tensorflow.python.tpu import device_assignment as device_assignment_lib # pylint: disable=unused-import
|
||||||
from tensorflow.python.tpu import tpu
|
from tensorflow.python.tpu import tpu
|
||||||
from tensorflow.python.tpu import tpu_strategy_util
|
from tensorflow.python.tpu import tpu_strategy_util
|
||||||
|
@ -515,7 +517,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||||
self._require_static_shapes = True
|
self._require_static_shapes = True
|
||||||
|
|
||||||
self.experimental_enable_get_next_as_optional = True
|
self.experimental_enable_get_next_as_optional = True
|
||||||
self._prefetch_to_device = True
|
|
||||||
|
|
||||||
self._logical_device_stack = [0]
|
self._logical_device_stack = [0]
|
||||||
|
|
||||||
|
@ -527,16 +528,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||||
context.async_wait()
|
context.async_wait()
|
||||||
atexit.register(async_wait)
|
atexit.register(async_wait)
|
||||||
|
|
||||||
# TODO(bfontain): Remove once a proper dataset API exists for prefetching
|
|
||||||
# a dataset to multiple devices.
|
|
||||||
# If value is true, this forces prefetch of data to the host's memory rather
|
|
||||||
# than the individual TPU device's memory. This is needed when used for TPU
|
|
||||||
# Embeddings as a) sparse tensors cannot be prefetched to the TPU device
|
|
||||||
# memory and b) TPU Embedding enqueue operations are CPU ops and this avoids
|
|
||||||
# a copy back to the host for dense tensors
|
|
||||||
def _set_prefetch_on_host(self, value):
|
|
||||||
self._prefetch_to_device = not value
|
|
||||||
|
|
||||||
def _validate_colocate_with_variable(self, colocate_with_variable):
|
def _validate_colocate_with_variable(self, colocate_with_variable):
|
||||||
distribute_utils. validate_colocate(colocate_with_variable, self)
|
distribute_utils. validate_colocate(colocate_with_variable, self)
|
||||||
|
|
||||||
|
@ -575,17 +566,32 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||||
session)
|
session)
|
||||||
|
|
||||||
def _get_input_workers(self, options):
|
def _get_input_workers(self, options):
|
||||||
prefetch_to_device = self._prefetch_to_device
|
if not options or options.experimental_prefetch_to_device:
|
||||||
if options:
|
|
||||||
prefetch_to_device = options.experimental_prefetch_to_device
|
|
||||||
if prefetch_to_device:
|
|
||||||
return input_lib.InputWorkers(
|
return input_lib.InputWorkers(
|
||||||
tuple(self._device_input_worker_devices.items()))
|
tuple(self._device_input_worker_devices.items()))
|
||||||
else:
|
else:
|
||||||
return input_lib.InputWorkers(
|
return input_lib.InputWorkers(
|
||||||
tuple(self._host_input_worker_devices.items()))
|
tuple(self._host_input_worker_devices.items()))
|
||||||
|
|
||||||
|
def _check_spec(self, element_spec):
|
||||||
|
if isinstance(element_spec, values.PerReplicaSpec):
|
||||||
|
element_spec = element_spec._component_specs # pylint: disable=protected-access
|
||||||
|
specs = nest.flatten_with_joined_string_paths(element_spec)
|
||||||
|
for path, spec in specs:
|
||||||
|
if isinstance(spec, (sparse_tensor.SparseTensorSpec,
|
||||||
|
ragged_tensor.RaggedTensorSpec)):
|
||||||
|
raise ValueError(
|
||||||
|
"Found tensor {} with spec {}. TPUStrategy does not support "
|
||||||
|
"distributed datasets with device prefetch when using sparse or "
|
||||||
|
"ragged tensors. If you indend to use sparse or ragged tensors, "
|
||||||
|
"please pass a tf.distribute.InputOptions object with "
|
||||||
|
"experimental_prefetch_to_device set to False to your dataset "
|
||||||
|
"distribution function.".format(path, type(spec)))
|
||||||
|
|
||||||
def _experimental_distribute_dataset(self, dataset, options):
|
def _experimental_distribute_dataset(self, dataset, options):
|
||||||
|
if options is None or options.experimental_prefetch_to_device:
|
||||||
|
self._check_spec(dataset.element_spec)
|
||||||
|
|
||||||
return input_lib.get_distributed_dataset(
|
return input_lib.get_distributed_dataset(
|
||||||
dataset,
|
dataset,
|
||||||
self._get_input_workers(options),
|
self._get_input_workers(options),
|
||||||
|
@ -603,12 +609,17 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
|
||||||
input_pipeline_id=i,
|
input_pipeline_id=i,
|
||||||
num_replicas_in_sync=self._num_replicas_in_sync))
|
num_replicas_in_sync=self._num_replicas_in_sync))
|
||||||
|
|
||||||
return input_lib.get_distributed_datasets_from_function(
|
distributed_dataset = input_lib.get_distributed_datasets_from_function(
|
||||||
dataset_fn,
|
dataset_fn,
|
||||||
input_workers,
|
input_workers,
|
||||||
input_contexts,
|
input_contexts,
|
||||||
self._container_strategy())
|
self._container_strategy())
|
||||||
|
|
||||||
|
# We can only check after the dataset_fn is called.
|
||||||
|
if options is None or options.experimental_prefetch_to_device:
|
||||||
|
self._check_spec(distributed_dataset.element_spec)
|
||||||
|
return distributed_dataset
|
||||||
|
|
||||||
def _experimental_distribute_values_from_function(self, value_fn):
|
def _experimental_distribute_values_from_function(self, value_fn):
|
||||||
per_replica_values = []
|
per_replica_values = []
|
||||||
for replica_id in range(self._num_replicas_in_sync):
|
for replica_id in range(self._num_replicas_in_sync):
|
||||||
|
|
|
@ -44,6 +44,7 @@ from tensorflow.python.ops import embedding_ops
|
||||||
from tensorflow.python.ops import math_ops
|
from tensorflow.python.ops import math_ops
|
||||||
from tensorflow.python.ops import random_ops
|
from tensorflow.python.ops import random_ops
|
||||||
from tensorflow.python.ops import variables
|
from tensorflow.python.ops import variables
|
||||||
|
from tensorflow.python.ops.ragged import ragged_tensor
|
||||||
from tensorflow.python.platform import flags
|
from tensorflow.python.platform import flags
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
from tensorflow.python.tpu import device_assignment as device_assignment_lib
|
from tensorflow.python.tpu import device_assignment as device_assignment_lib
|
||||||
|
@ -475,9 +476,11 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
return dataset.map(make_sparse)
|
return dataset.map(make_sparse)
|
||||||
|
|
||||||
strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
|
|
||||||
dataset = iter(
|
dataset = iter(
|
||||||
strategy.experimental_distribute_datasets_from_function(dataset_fn))
|
strategy.experimental_distribute_datasets_from_function(
|
||||||
|
dataset_fn,
|
||||||
|
distribute_lib.InputOptions(
|
||||||
|
experimental_prefetch_to_device=False)))
|
||||||
|
|
||||||
result = sparse_lookup(dataset)
|
result = sparse_lookup(dataset)
|
||||||
self.assertAllEqual(result,
|
self.assertAllEqual(result,
|
||||||
|
@ -520,9 +523,11 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
|
||||||
|
|
||||||
return dataset.map(make_sparse)
|
return dataset.map(make_sparse)
|
||||||
|
|
||||||
strategy.extended._set_prefetch_on_host(True) # pylint: disable=protected-access
|
|
||||||
dataset = iter(
|
dataset = iter(
|
||||||
strategy.experimental_distribute_datasets_from_function(dataset_fn))
|
strategy.experimental_distribute_datasets_from_function(
|
||||||
|
dataset_fn,
|
||||||
|
options=distribute_lib.InputOptions(
|
||||||
|
experimental_prefetch_to_device=False)))
|
||||||
|
|
||||||
result = sparse_lookup(dataset)
|
result = sparse_lookup(dataset)
|
||||||
self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])
|
self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])
|
||||||
|
@ -594,5 +599,61 @@ class TPUStrategyDataPrefetchTest(test.TestCase):
|
||||||
dataset_item.values[0].device)
|
dataset_item.values[0].device)
|
||||||
self.assertEqual(dataset_location.device_type, "CPU")
|
self.assertEqual(dataset_location.device_type, "CPU")
|
||||||
|
|
||||||
|
def test_prefetch_to_device_sparse_dataset(self):
|
||||||
|
strategy = get_tpu_strategy()
|
||||||
|
# Values here aren't important.
|
||||||
|
dataset = dataset_ops.Dataset.from_tensors(
|
||||||
|
sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
|
||||||
|
values=[1, 2, 3],
|
||||||
|
dense_shape=[2, 2]))
|
||||||
|
dataset = dataset.repeat()
|
||||||
|
dataset = dataset.batch(strategy.num_replicas_in_sync)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
|
||||||
|
iter(strategy.experimental_distribute_dataset(dataset))
|
||||||
|
|
||||||
|
def test_prefetch_to_device_ragged_dataset(self):
|
||||||
|
strategy = get_tpu_strategy()
|
||||||
|
# Values here aren't important.
|
||||||
|
dataset = dataset_ops.Dataset.from_tensors(
|
||||||
|
ragged_tensor.RaggedTensor.from_row_splits(
|
||||||
|
values=[1, 2, 3],
|
||||||
|
row_splits=[0, 2, 3]))
|
||||||
|
dataset = dataset.repeat()
|
||||||
|
dataset = dataset.batch(strategy.num_replicas_in_sync)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
|
||||||
|
iter(strategy.experimental_distribute_dataset(dataset))
|
||||||
|
|
||||||
|
def test_prefetch_to_device_sparse_dataset_fn(self):
|
||||||
|
strategy = get_tpu_strategy()
|
||||||
|
def dataset_fn(ctx):
|
||||||
|
del ctx
|
||||||
|
# Values here aren't important.
|
||||||
|
dataset = dataset_ops.Dataset.from_tensors(
|
||||||
|
sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
|
||||||
|
values=[1, 2, 3],
|
||||||
|
dense_shape=[2, 2]))
|
||||||
|
dataset = dataset.repeat()
|
||||||
|
return dataset.batch(strategy.num_replicas_in_sync)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
|
||||||
|
iter(strategy.experimental_distribute_datasets_from_function(dataset_fn))
|
||||||
|
|
||||||
|
def test_prefetch_to_device_ragged_dataset_fn(self):
|
||||||
|
strategy = get_tpu_strategy()
|
||||||
|
def dataset_fn(ctx):
|
||||||
|
del ctx
|
||||||
|
# Values here aren't important.
|
||||||
|
dataset = dataset_ops.Dataset.from_tensors(
|
||||||
|
ragged_tensor.RaggedTensor.from_row_splits(
|
||||||
|
values=[1, 2, 3],
|
||||||
|
row_splits=[0, 2, 3]))
|
||||||
|
dataset = dataset.repeat()
|
||||||
|
return dataset.batch(strategy.num_replicas_in_sync)
|
||||||
|
|
||||||
|
with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
|
||||||
|
iter(strategy.experimental_distribute_datasets_from_function(dataset_fn))
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
test.main()
|
test.main()
|
||||||
|
|
Loading…
Reference in New Issue