Clean up _set_prefetch_on_host method from TPUStrategy.
PiperOrigin-RevId: 317672388
Change-Id: Icb03401a792df4d071659f101ba06edf59010577
commit eb32bf1ae4
parent 22f6939be9
@@ -454,10 +454,11 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
       ))
   def testDistributeDatasetHostPrefetch(self, distribution):
     data = [5., 6., 7., 8.]
-    distribution.extended._set_prefetch_on_host(True)  # pylint: disable=protected-access
     input_iterator = iter(
         distribution.experimental_distribute_dataset(
-            get_dataset_from_tensor_slices(data).batch(2)))
+            get_dataset_from_tensor_slices(data).batch(2),
+            distribute_lib.InputOptions(
+                experimental_prefetch_to_device=False)))
 
     local_results = distribution.experimental_local_results(
         input_iterator.get_next())
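The user-visible migration is mechanical: instead of flipping the private `_set_prefetch_on_host` switch on the strategy, callers pass a `tf.distribute.InputOptions` (imported as `distribute_lib.InputOptions` inside TensorFlow) at distribution time. A minimal before/after sketch, assuming a TPU is reachable; the resolver setup and the `tpu=""` target are illustrative boilerplate, not part of this change:

    import tensorflow as tf

    resolver = tf.distribute.cluster_resolver.TPUClusterResolver(tpu="")  # assumes a local TPU
    tf.config.experimental_connect_to_cluster(resolver)
    tf.tpu.experimental.initialize_tpu_system(resolver)
    strategy = tf.distribute.experimental.TPUStrategy(resolver)

    dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)

    # Before (removed by this change): a strategy-wide private switch.
    #   strategy.extended._set_prefetch_on_host(True)
    #   dist_iter = iter(strategy.experimental_distribute_dataset(dataset))

    # After: a per-call, public option that keeps the data in host memory.
    dist_iter = iter(
        strategy.experimental_distribute_dataset(
            dataset,
            tf.distribute.InputOptions(experimental_prefetch_to_device=False)))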
@@ -473,10 +474,11 @@ class InputIterationTest(test.TestCase, parameterized.TestCase,
       ))
   def testDistributeDatasetFunctionHostPrefetch(self, distribution):
     data = [5., 6., 7., 8.]
-    distribution.extended._set_prefetch_on_host(True)  # pylint: disable=protected-access
     input_iterator = iter(
         distribution.experimental_distribute_datasets_from_function(
-            lambda _: get_dataset_from_tensor_slices(data)))
+            lambda _: get_dataset_from_tensor_slices(data),
+            distribute_lib.InputOptions(
+                experimental_prefetch_to_device=False)))
 
     local_results = distribution.experimental_local_results(
         input_iterator.get_next())
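The function-based entry point takes the same options object, as the second hunk above shows. Continuing from the sketch above (the `strategy` object is the one built there):

    def dataset_fn(input_context):
        del input_context  # per-replica sharding is not relevant to this sketch
        return tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2)

    dist_iter = iter(
        strategy.experimental_distribute_datasets_from_function(
            dataset_fn,
            options=tf.distribute.InputOptions(
                experimental_prefetch_to_device=False)))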
@@ -47,12 +47,14 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import device_spec
 from tensorflow.python.framework import dtypes
 from tensorflow.python.framework import ops
+from tensorflow.python.framework import sparse_tensor
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.framework import tensor_util
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import resource_variable_ops
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.tpu import device_assignment as device_assignment_lib  # pylint: disable=unused-import
 from tensorflow.python.tpu import tpu
 from tensorflow.python.tpu import tpu_strategy_util
@@ -515,7 +517,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
     self._require_static_shapes = True
 
     self.experimental_enable_get_next_as_optional = True
-    self._prefetch_to_device = True
 
     self._logical_device_stack = [0]
 
@@ -527,16 +528,6 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
           context.async_wait()
       atexit.register(async_wait)
 
-  # TODO(bfontain): Remove once a proper dataset API exists for prefetching
-  # a dataset to multiple devices exists.
-  # If value is true, this forces prefetch of data to the host's memeory rather
-  # than the individual TPU device's memory. This is needed when using for TPU
-  # Embeddings as a) sparse tensors cannot be prefetched to the TPU device
-  # memory and b) TPU Embedding enqueue operation are CPU ops and this avoids
-  # a copy back to the host for dense tensors
-  def _set_prefetch_on_host(self, value):
-    self._prefetch_to_device = not value
-
   def _validate_colocate_with_variable(self, colocate_with_variable):
     distribute_utils.validate_colocate(colocate_with_variable, self)
 
@@ -575,17 +566,32 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
         session)
 
   def _get_input_workers(self, options):
-    prefetch_to_device = self._prefetch_to_device
-    if options:
-      prefetch_to_device = options.experimental_prefetch_to_device
-    if prefetch_to_device:
+    if not options or options.experimental_prefetch_to_device:
       return input_lib.InputWorkers(
           tuple(self._device_input_worker_devices.items()))
     else:
       return input_lib.InputWorkers(
           tuple(self._host_input_worker_devices.items()))
 
+  def _check_spec(self, element_spec):
+    if isinstance(element_spec, values.PerReplicaSpec):
+      element_spec = element_spec._component_specs  # pylint: disable=protected-access
+    specs = nest.flatten_with_joined_string_paths(element_spec)
+    for path, spec in specs:
+      if isinstance(spec, (sparse_tensor.SparseTensorSpec,
+                           ragged_tensor.RaggedTensorSpec)):
+        raise ValueError(
+            "Found tensor {} with spec {}. TPUStrategy does not support "
+            "distributed datasets with device prefetch when using sparse or "
+            "ragged tensors. If you intend to use sparse or ragged tensors, "
+            "please pass a tf.distribute.InputOptions object with "
+            "experimental_prefetch_to_device set to False to your dataset "
+            "distribution function.".format(path, type(spec)))
+
   def _experimental_distribute_dataset(self, dataset, options):
+    if options is None or options.experimental_prefetch_to_device:
+      self._check_spec(dataset.element_spec)
+
     return input_lib.get_distributed_dataset(
         dataset,
         self._get_input_workers(options),
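The new `_check_spec` guard rejects sparse and ragged components up front when device prefetch is enabled, rather than failing later inside the prefetch path. A rough standalone equivalent using only public APIs; the helper name is illustrative, not TensorFlow's internal code path:

    import tensorflow as tf

    def reject_sparse_or_ragged(element_spec):
        # Walk the nested element spec without expanding composite tensors,
        # so sparse/ragged components surface as their TypeSpec objects.
        for spec in tf.nest.flatten(element_spec, expand_composites=False):
            if isinstance(spec, (tf.SparseTensorSpec, tf.RaggedTensorSpec)):
                raise ValueError(
                    "spec {} cannot be prefetched to a TPU device; pass "
                    "tf.distribute.InputOptions("
                    "experimental_prefetch_to_device=False)".format(spec))

    sparse_ds = tf.data.Dataset.from_tensors(
        tf.SparseTensor(indices=[[0, 0]], values=[1], dense_shape=[2, 2]))
    reject_sparse_or_ragged(sparse_ds.element_spec)  # raises ValueError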
@@ -603,12 +609,17 @@ class TPUExtended(distribute_lib.StrategyExtendedV1):
           input_pipeline_id=i,
           num_replicas_in_sync=self._num_replicas_in_sync))
 
-    return input_lib.get_distributed_datasets_from_function(
+    distributed_dataset = input_lib.get_distributed_datasets_from_function(
         dataset_fn,
         input_workers,
         input_contexts,
         self._container_strategy())
 
+    # We can only check after the dataset_fn is called.
+    if options is None or options.experimental_prefetch_to_device:
+      self._check_spec(distributed_dataset.element_spec)
+    return distributed_dataset
+
   def _experimental_distribute_values_from_function(self, value_fn):
     per_replica_values = []
     for replica_id in range(self._num_replicas_in_sync):
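For the function-based path the element spec is only known after `dataset_fn` runs, so the hunk above builds the distributed dataset first and validates its `element_spec` afterwards. A hypothetical sketch of that ordering; the names are illustrative, not the internal implementation:

    def build_then_validate(dataset_fn, input_contexts, validate_spec):
        # The spec cannot be inspected before the user function materializes
        # the datasets, so construction has to come first.
        datasets = [dataset_fn(ctx) for ctx in input_contexts]
        for ds in datasets:
            validate_spec(ds.element_spec)
        return datasets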
@@ -44,6 +44,7 @@ from tensorflow.python.ops import embedding_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
+from tensorflow.python.ops.ragged import ragged_tensor
 from tensorflow.python.platform import flags
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.tpu import device_assignment as device_assignment_lib
@@ -475,9 +476,11 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
 
       return dataset.map(make_sparse)
 
-    strategy.extended._set_prefetch_on_host(True)  # pylint: disable=protected-access
     dataset = iter(
-        strategy.experimental_distribute_datasets_from_function(dataset_fn))
+        strategy.experimental_distribute_datasets_from_function(
+            dataset_fn,
+            distribute_lib.InputOptions(
+                experimental_prefetch_to_device=False)))
 
     result = sparse_lookup(dataset)
     self.assertAllEqual(result,
@@ -520,9 +523,11 @@ class TPUStrategyTest(test.TestCase, parameterized.TestCase):
 
       return dataset.map(make_sparse)
 
-    strategy.extended._set_prefetch_on_host(True)  # pylint: disable=protected-access
     dataset = iter(
-        strategy.experimental_distribute_datasets_from_function(dataset_fn))
+        strategy.experimental_distribute_datasets_from_function(
+            dataset_fn,
+            options=distribute_lib.InputOptions(
+                experimental_prefetch_to_device=False)))
 
     result = sparse_lookup(dataset)
     self.assertAllEqual(result, [[0.0, 2.0], [1.5, 5.0]])
@@ -594,5 +599,61 @@ class TPUStrategyDataPrefetchTest(test.TestCase):
         dataset_item.values[0].device)
     self.assertEqual(dataset_location.device_type, "CPU")
 
+  def test_prefetch_to_device_sparse_dataset(self):
+    strategy = get_tpu_strategy()
+    # Values here aren't important.
+    dataset = dataset_ops.Dataset.from_tensors(
+        sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
+                                   values=[1, 2, 3],
+                                   dense_shape=[2, 2]))
+    dataset = dataset.repeat()
+    dataset = dataset.batch(strategy.num_replicas_in_sync)
+
+    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
+      iter(strategy.experimental_distribute_dataset(dataset))
+
+  def test_prefetch_to_device_ragged_dataset(self):
+    strategy = get_tpu_strategy()
+    # Values here aren't important.
+    dataset = dataset_ops.Dataset.from_tensors(
+        ragged_tensor.RaggedTensor.from_row_splits(
+            values=[1, 2, 3],
+            row_splits=[0, 2, 3]))
+    dataset = dataset.repeat()
+    dataset = dataset.batch(strategy.num_replicas_in_sync)
+
+    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
+      iter(strategy.experimental_distribute_dataset(dataset))
+
+  def test_prefetch_to_device_sparse_dataset_fn(self):
+    strategy = get_tpu_strategy()
+    def dataset_fn(ctx):
+      del ctx
+      # Values here aren't important.
+      dataset = dataset_ops.Dataset.from_tensors(
+          sparse_tensor.SparseTensor(indices=[[0, 0], [0, 1], [1, 0]],
+                                     values=[1, 2, 3],
+                                     dense_shape=[2, 2]))
+      dataset = dataset.repeat()
+      return dataset.batch(strategy.num_replicas_in_sync)
+
+    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
+      iter(strategy.experimental_distribute_datasets_from_function(dataset_fn))
+
+  def test_prefetch_to_device_ragged_dataset_fn(self):
+    strategy = get_tpu_strategy()
+    def dataset_fn(ctx):
+      del ctx
+      # Values here aren't important.
+      dataset = dataset_ops.Dataset.from_tensors(
+          ragged_tensor.RaggedTensor.from_row_splits(
+              values=[1, 2, 3],
+              row_splits=[0, 2, 3]))
+      dataset = dataset.repeat()
+      return dataset.batch(strategy.num_replicas_in_sync)
+
+    with self.assertRaisesRegex(ValueError, "TPUStrategy does not support"):
+      iter(strategy.experimental_distribute_datasets_from_function(dataset_fn))
+
 if __name__ == "__main__":
   test.main()
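With the guard in place, distributing a sparse or ragged dataset without opting out of device prefetch fails fast, which is what the four new tests pin down. Condensed, on a live TPU (reusing the `strategy` from the first sketch) the failure path looks like:

    ragged_ds = tf.data.Dataset.from_tensors(
        tf.RaggedTensor.from_row_splits(values=[1, 2, 3], row_splits=[0, 2, 3])
    ).repeat().batch(strategy.num_replicas_in_sync)

    try:
        iter(strategy.experimental_distribute_dataset(ragged_ds))
    except ValueError as err:
        print(err)  # "... TPUStrategy does not support distributed datasets ..."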