Enabling prefetch_on_device in eager mode for distribution strategies
PiperOrigin-RevId: 226529748
parent 181da18675
commit 08a692fe30
@@ -188,7 +188,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
         self.evaluate(elem_on_2)
 
   @test_util.run_v1_only
-  def testMultipleInitializations(self):
+  def testMultipleInitializationsGraph(self):
     if context.executing_eagerly():
       return
 
@@ -209,6 +209,22 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
         self.assertEqual([(i, 0), (i, 1)], self.evaluate([elem_on_1,
                                                           elem_on_2]))
 
+  @test_util.run_v1_only
+  def testMultipleInitializationsEager(self):
+    if not context.executing_eagerly():
+      return
+
+    with ops.device("/cpu:0"):
+      dataset1 = dataset_ops.Dataset.range(1000)
+      dataset2 = dataset_ops.Dataset.range(1000)
+      dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
+
+    for _ in range(1000):
+      multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+          dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
+      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+      self.assertEqual([(0, 0), (1, 1)], self.evaluate([elem_on_1, elem_on_2]))
+
   @test_util.run_v1_only
   def testBasicGpu(self):
     if not test_util.is_gpu_available():
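Note: the new eager test zips two identical range datasets, so every element is the pair (i, i); each freshly constructed iterator hands element 0 to /cpu:1 and element 1 to /cpu:2, which is why all 1000 iterations expect [(0, 0), (1, 1)]. A plain-Python sketch of that expectation, no TensorFlow needed:

    # Sketch of the expectation in testMultipleInitializationsEager: zipping
    # two identical ranges yields pairs (i, i); a fresh iterator gives the
    # first device element 0 and the second device element 1.
    pairs = list(zip(range(1000), range(1000)))
    elem_on_1, elem_on_2 = pairs[0], pairs[1]
    assert [elem_on_1, elem_on_2] == [(0, 0), (1, 1)]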
@@ -160,10 +160,16 @@ class MultiDeviceIterator(object):
 
     # Create the MultiDeviceIterator.
     with ops.device(self._source_device):
+      # TODO(b/121378567): Get rid of this shared_name hack.
+      shared_name = ""
+      if context.executing_eagerly():
+        # Ensure a unique name when eager execution is enabled to avoid spurious
+        # sharing issues.
+        shared_name += str(ops.uid())
       self._multi_device_iterator_resource = (
           gen_dataset_ops.multi_device_iterator(
               devices=self._devices,
-              shared_name="",
+              shared_name=shared_name,
               container="",
               **dataset_ops.flat_structure(dataset)))
 
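Note: this hunk matters because in eager mode every MultiDeviceIterator construction should own a fresh resource; with a constant empty shared_name, repeated constructions could spuriously share one. A minimal standalone sketch of the naming pattern, using itertools.count in place of the internal ops.uid() helper:

    import itertools

    _uid = itertools.count()  # stand-in for the internal ops.uid() counter

    def make_shared_name(executing_eagerly):
      # Graph mode keeps the empty name; eager mode appends a unique id so
      # that every constructed iterator owns a distinct resource.
      shared_name = ""
      if executing_eagerly:
        shared_name += str(next(_uid))
      return shared_name

    assert make_shared_name(False) == ""
    assert make_shared_name(True) != make_shared_name(True)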
@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import collections
 import contextlib
-import operator
 import weakref
 import six
 
@@ -331,7 +330,10 @@ class DistributedDelegate(DistributedValues):
   def __rmul__(self, o): return o * self.get()
   def __truediv__(self, o): return self.get() / o
   def __rtruediv__(self, o): return o / self.get()
-  def __floordiv__(self, o): return self.get() // o
+
+  def __floordiv__(self, o):
+    return self.get() // o
+
   def __rfloordiv__(self, o): return o // self.get()
   def __mod__(self, o): return self.get() % o
   def __rmod__(self, o): return o % self.get()
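Note: for reference, a self-contained sketch of the delegation pattern DistributedDelegate uses here: each arithmetic dunder forwards to the value returned by get(), so // works transparently on wrapped per-replica values. The Delegate class below is illustrative, not part of the change:

    class Delegate(object):
      """Illustrative stand-in for DistributedDelegate's operator forwarding."""

      def __init__(self, value):
        self._value = value

      def get(self):
        return self._value

      def __floordiv__(self, o):
        return self.get() // o

      def __rfloordiv__(self, o):
        return o // self.get()

    assert Delegate(7) // 2 == 3
    assert 20 // Delegate(7) == 2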
@@ -1492,21 +1494,22 @@ class PerReplicaDataset(object):
     self._input_workers = input_workers
     self._worker_index = worker_index
 
-    # Default to using prefetching in graph mode, unless specified.
-    # TODO(rohanj): Enable prefetching in eager mode.
+    # Default to using prefetching, unless specified.
     self._prefetch_on_device = prefetch_on_device
     if self._prefetch_on_device is None:
-      self._prefetch_on_device = not context.executing_eagerly()
-    assert not (self._prefetch_on_device and context.executing_eagerly()), (
-        "Prefetching is only supported in graph mode currently")
+      self._prefetch_on_device = True
 
     self._dataset = dataset
     if not self._prefetch_on_device:
       # TODO(priyag): If dropping remainder is not appropriate, find another
       # approach to distributing the dataset when not possible to divide evenly.
       # Possibly not an issue when we start using PartitionedDataset.
-      num_replicas = len(input_workers.compute_devices_for_worker(worker_index))
-      self._dataset = dataset.batch(num_replicas, drop_remainder=True)
+      num_replicas = len(
+          self._input_workers.compute_devices_for_worker(self._worker_index))
+      self._dataset = self._dataset.batch(num_replicas, drop_remainder=True)
+    else:
+      self._replica_devices = self._input_workers.compute_devices_for_worker(
+          self._worker_index)
 
   def make_one_shot_iterator(self):
     """Get a one time use iterator for the distributed PerReplicaDataset."""
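Note: the constructor change boils down to two rules: prefetch_on_device now defaults to True in both execution modes (the eager-mode assertion is gone), and the non-prefetching fallback still batches by replica count with drop_remainder=True so the per-replica split stays even. A hedged sketch of that decision logic (resolve_prefetching is a hypothetical helper, not code from the change):

    def resolve_prefetching(prefetch_on_device, num_replicas):
      # Default to prefetching unless the caller opted out explicitly;
      # previously the default was False under eager execution.
      if prefetch_on_device is None:
        prefetch_on_device = True
      if prefetch_on_device:
        batch_size = None  # MultiDeviceIterator feeds each device directly.
      else:
        # One batched stream is sliced across replicas, so batches must
        # divide evenly: batch(num_replicas, drop_remainder=True).
        batch_size = num_replicas
      return prefetch_on_device, batch_size

    assert resolve_prefetching(None, 2) == (True, None)
    assert resolve_prefetching(False, 2) == (False, 2)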
@@ -1514,13 +1517,16 @@ class PerReplicaDataset(object):
     if not context.executing_eagerly():
       raise ValueError("Cannot create a one shot iterator. Please use "
                        "`make_initializable_iterator()` instead.")
-    # Eager mode prefetching would error out in constructor. Only remaining
-    # case is non-prefetching in eager mode. We delegate to
-    # PerReplicaDataIterator to handle that case.
-    dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
+    if self._prefetch_on_device:
+      dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+          self._dataset, self._replica_devices)
+    else:
+      dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
     return PerReplicaDataIterator(
-        dataset_iterator, self._input_workers, self._worker_index,
-        prefetch_on_device=False)
+        dataset_iterator,
+        self._input_workers,
+        self._worker_index,
+        prefetch_on_device=self._prefetch_on_device)
 
   def make_initializable_iterator(self):
     """Get an initializable iterator for the distributed PerReplicaDataset."""
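Note: with the constructor assertion removed, make_one_shot_iterator can pick the iterator type from the flag alone. A standalone sketch of the selection, with injected factory callables standing in for multi_device_iterator_ops.MultiDeviceIterator and dataset_ops.make_one_shot_iterator:

    def pick_iterator(prefetch_on_device, make_multi_device, make_one_shot):
      # Prefetching uses a per-device iterator; otherwise fall back to a
      # plain one-shot iterator over the batched dataset.
      if prefetch_on_device:
        return make_multi_device()
      return make_one_shot()

    assert pick_iterator(True, lambda: "mdi", lambda: "one_shot") == "mdi"
    assert pick_iterator(False, lambda: "mdi", lambda: "one_shot") == "one_shot"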
@@ -1530,10 +1536,8 @@ class PerReplicaDataset(object):
       raise ValueError("Cannot create initializable iterator in Eager mode. "
                        "Please use `make_one_shot_iterator` instead.")
     if self._prefetch_on_device:
-      replica_devices = self._input_workers.compute_devices_for_worker(
-          self._worker_index)
       dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
-          self._dataset, replica_devices)
+          self._dataset, self._replica_devices)
     else:
       dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset)
     return PerReplicaDataIterator(
@@ -1690,13 +1694,9 @@ class InputIteratorImpl(InputIterator):
 
     self._iterators = iterators
     self._input_workers = input_workers
-    self._is_eager = context.executing_eagerly()
 
   def get_next(self, name=None):
     """Returns the next input from the iterator for all replicas."""
-    assert self._is_eager == context.executing_eagerly(), (
-        "Iterator should be created and used in same execution mode.")
-
     replicas = []
     for i, worker in enumerate(self._input_workers.worker_devices):
       if name is not None:
@@ -1716,9 +1716,6 @@ class InputIteratorImpl(InputIterator):
     Returns:
       A list of any initializer ops that should be run.
     """
-    assert self._is_eager == context.executing_eagerly(), (
-        "Iterator should be created and used in same execution mode.")
-
     init_ops = []
     for it in self._iterators:
       init_ops.extend(it.initialize())
@@ -1842,7 +1839,7 @@ class _SingleWorkerDatasetIterator(object):
     """Create iterator for the `dataset` to fetch data to worker's `devices` .
 
     `MultiDeviceIterator` is used to prefetch input to the devices on the
-    given worker. `MultiDeviceIterator` doesn't work in eager mode yet.
+    given worker.
 
     Args:
       dataset: A `tf.data.Dataset` instance.
@@ -1852,39 +1849,19 @@ class _SingleWorkerDatasetIterator(object):
     self._dataset = dataset
     self._worker = worker
     self._devices = devices
-    self._is_eager = context.executing_eagerly()
     self._make_iterator()
 
   def _make_iterator(self):
     """Make appropriate iterator on the dataset."""
     with ops.device(self._worker):
-      if self._is_eager:
-        # TODO(rohanj): Enable prefetching in eager mode.
-        # TODO(priyag): Measure the performance of this approach vs calling
-        # get_next on the original dataset N times.
-        dataset = self._dataset.batch(len(self._devices), drop_remainder=True)
-        iterator = dataset_ops.make_one_shot_iterator(dataset)
-      else:
-        iterator = multi_device_iterator_ops.MultiDeviceIterator(
-            self._dataset, self._devices)
-    self._iterator = iterator
+      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
+          self._dataset, self._devices)
 
   def get_next_as_list(self, name=None):
     """Get next element from the underlying iterator."""
+    del name
     with ops.device(self._worker):
-      if self._is_eager:
-        # Batched dataset case.
-        batch = self._iterator.get_next(name=name)
-        data_list = []
-        for i, d in enumerate(self._devices):
-          v = nest.map_structure(operator.itemgetter(i), batch)
-          with ops.device(d):
-            v = nest.map_structure(array_ops.identity, v)
-          data_list.append(v)
-      else:
-        # MultiDeviceIterator case.
-        data_list = self._iterator.get_next()
-
+      data_list = self._iterator.get_next()
       return data_list
 
   def initialize(self):
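Note: for contrast, the deleted eager path batched once and then sliced element i of each batch onto device i via operator.itemgetter, which is why the import operator line could be dropped earlier in this commit. A plain-Python sketch of that removed splitting scheme:

    import operator

    def split_batch_across_devices(batch, devices):
      # The removed code did this with nest.map_structure so that nested
      # structures were sliced element-wise; plain lists suffice here.
      return [operator.itemgetter(i)(batch) for i in range(len(devices))]

    # A "batch" of two elements lands one per device.
    assert split_batch_across_devices([10, 11], ["/gpu:0", "/gpu:1"]) == [10, 11]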
@@ -1897,7 +1874,7 @@ class _SingleWorkerDatasetIterator(object):
     Returns:
       A list of any initializer ops that should be run.
     """
-    if self._is_eager:
+    if context.executing_eagerly():
      self._make_iterator()
       return []
     else:
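Note: the last hunk keys re-initialization off the live execution mode instead of the removed cached flag: eager mode rebuilds the MultiDeviceIterator in place (there is no init op to run), while graph mode returns initializer ops. A minimal standalone sketch of that split (SketchIterator and "init_op" are illustrative):

    class SketchIterator(object):
      """Illustrative model of the eager/graph initialize() split."""

      def __init__(self, executing_eagerly):
        self._executing_eagerly = executing_eagerly
        self._builds = 0
        self._make_iterator()

      def _make_iterator(self):
        self._builds += 1  # stands in for constructing a MultiDeviceIterator

      def initialize(self):
        if self._executing_eagerly:
          self._make_iterator()  # rebuild; nothing for the caller to run
          return []
        return ["init_op"]       # graph mode hands back ops to run

    eager_it = SketchIterator(executing_eagerly=True)
    assert eager_it.initialize() == [] and eager_it._builds == 2
    graph_it = SketchIterator(executing_eagerly=False)
    assert graph_it.initialize() == ["init_op"]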