Enabling prefetch_on_device in eager mode for distribution strategies
PiperOrigin-RevId: 226529748
This commit is contained in:
parent
181da18675
commit
08a692fe30
@ -188,7 +188,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
self.evaluate(elem_on_2)
|
||||
|
||||
@test_util.run_v1_only
|
||||
def testMultipleInitializations(self):
|
||||
def testMultipleInitializationsGraph(self):
|
||||
if context.executing_eagerly():
|
||||
return
|
||||
|
||||
@ -209,6 +209,22 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
|
||||
self.assertEqual([(i, 0), (i, 1)], self.evaluate([elem_on_1,
|
||||
elem_on_2]))
|
||||
|
||||
@test_util.run_v1_only
|
||||
def testMultipleInitializationsEager(self):
|
||||
if not context.executing_eagerly():
|
||||
return
|
||||
|
||||
with ops.device("/cpu:0"):
|
||||
dataset1 = dataset_ops.Dataset.range(1000)
|
||||
dataset2 = dataset_ops.Dataset.range(1000)
|
||||
dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
|
||||
|
||||
for _ in range(1000):
|
||||
multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
|
||||
elem_on_1, elem_on_2 = multi_device_iterator.get_next()
|
||||
self.assertEqual([(0, 0), (1, 1)], self.evaluate([elem_on_1, elem_on_2]))
|
||||
|
||||
@test_util.run_v1_only
|
||||
def testBasicGpu(self):
|
||||
if not test_util.is_gpu_available():
|
||||
|
@ -160,10 +160,16 @@ class MultiDeviceIterator(object):
|
||||
|
||||
# Create the MultiDeviceIterator.
|
||||
with ops.device(self._source_device):
|
||||
# TODO(b/121378567): Get rid of this shared_name hack.
|
||||
shared_name = ""
|
||||
if context.executing_eagerly():
|
||||
# Ensure a unique name when eager execution is enabled to avoid spurious
|
||||
# sharing issues.
|
||||
shared_name += str(ops.uid())
|
||||
self._multi_device_iterator_resource = (
|
||||
gen_dataset_ops.multi_device_iterator(
|
||||
devices=self._devices,
|
||||
shared_name="",
|
||||
shared_name=shared_name,
|
||||
container="",
|
||||
**dataset_ops.flat_structure(dataset)))
|
||||
|
||||
|
@ -20,7 +20,6 @@ from __future__ import print_function
|
||||
|
||||
import collections
|
||||
import contextlib
|
||||
import operator
|
||||
import weakref
|
||||
import six
|
||||
|
||||
@ -331,7 +330,10 @@ class DistributedDelegate(DistributedValues):
|
||||
def __rmul__(self, o): return o * self.get()
|
||||
def __truediv__(self, o): return self.get() / o
|
||||
def __rtruediv__(self, o): return o / self.get()
|
||||
def __floordiv__(self, o): return self.get() // o
|
||||
|
||||
def __floordiv__(self, o):
|
||||
return self.get() // o
|
||||
|
||||
def __rfloordiv__(self, o): return o // self.get()
|
||||
def __mod__(self, o): return self.get() % o
|
||||
def __rmod__(self, o): return o % self.get()
|
||||
@ -1492,21 +1494,22 @@ class PerReplicaDataset(object):
|
||||
self._input_workers = input_workers
|
||||
self._worker_index = worker_index
|
||||
|
||||
# Default to using prefetching in graph mode, unless specified.
|
||||
# TODO(rohanj): Enable prefetching in eager mode.
|
||||
# Default to using prefetching, unless specified.
|
||||
self._prefetch_on_device = prefetch_on_device
|
||||
if self._prefetch_on_device is None:
|
||||
self._prefetch_on_device = not context.executing_eagerly()
|
||||
assert not (self._prefetch_on_device and context.executing_eagerly()), (
|
||||
"Prefetching is only supported in graph mode currently")
|
||||
self._prefetch_on_device = True
|
||||
|
||||
self._dataset = dataset
|
||||
if not self._prefetch_on_device:
|
||||
# TODO(priyag): If dropping remainder is not appropriate, find another
|
||||
# approach to distributing the dataset when not possible to divide evenly.
|
||||
# Possibly not an issue when we start using PartitionedDataset.
|
||||
num_replicas = len(input_workers.compute_devices_for_worker(worker_index))
|
||||
self._dataset = dataset.batch(num_replicas, drop_remainder=True)
|
||||
num_replicas = len(
|
||||
self._input_workers.compute_devices_for_worker(self._worker_index))
|
||||
self._dataset = self._dataset.batch(num_replicas, drop_remainder=True)
|
||||
else:
|
||||
self._replica_devices = self._input_workers.compute_devices_for_worker(
|
||||
self._worker_index)
|
||||
|
||||
def make_one_shot_iterator(self):
|
||||
"""Get a one time use iterator for the distributed PerReplicaDataset."""
|
||||
@ -1514,13 +1517,16 @@ class PerReplicaDataset(object):
|
||||
if not context.executing_eagerly():
|
||||
raise ValueError("Cannot create a one shot iterator. Please use "
|
||||
"`make_initializable_iterator()` instead.")
|
||||
# Eager mode prefetching would error out in constructor. Only remaining
|
||||
# case is non-prefetching in eager mode. We delegate to
|
||||
# PerReplicaDataIterator to handle that case.
|
||||
if self._prefetch_on_device:
|
||||
dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
self._dataset, self._replica_devices)
|
||||
else:
|
||||
dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
|
||||
return PerReplicaDataIterator(
|
||||
dataset_iterator, self._input_workers, self._worker_index,
|
||||
prefetch_on_device=False)
|
||||
dataset_iterator,
|
||||
self._input_workers,
|
||||
self._worker_index,
|
||||
prefetch_on_device=self._prefetch_on_device)
|
||||
|
||||
def make_initializable_iterator(self):
|
||||
"""Get an initializable iterator for the distributed PerReplicaDataset."""
|
||||
@ -1530,10 +1536,8 @@ class PerReplicaDataset(object):
|
||||
raise ValueError("Cannot create initializable iterator in Eager mode. "
|
||||
"Please use `make_one_shot_iterator` instead.")
|
||||
if self._prefetch_on_device:
|
||||
replica_devices = self._input_workers.compute_devices_for_worker(
|
||||
self._worker_index)
|
||||
dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
self._dataset, replica_devices)
|
||||
self._dataset, self._replica_devices)
|
||||
else:
|
||||
dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset)
|
||||
return PerReplicaDataIterator(
|
||||
@ -1690,13 +1694,9 @@ class InputIteratorImpl(InputIterator):
|
||||
|
||||
self._iterators = iterators
|
||||
self._input_workers = input_workers
|
||||
self._is_eager = context.executing_eagerly()
|
||||
|
||||
def get_next(self, name=None):
|
||||
"""Returns the next input from the iterator for all replicas."""
|
||||
assert self._is_eager == context.executing_eagerly(), (
|
||||
"Iterator should be created and used in same execution mode.")
|
||||
|
||||
replicas = []
|
||||
for i, worker in enumerate(self._input_workers.worker_devices):
|
||||
if name is not None:
|
||||
@ -1716,9 +1716,6 @@ class InputIteratorImpl(InputIterator):
|
||||
Returns:
|
||||
A list of any initializer ops that should be run.
|
||||
"""
|
||||
assert self._is_eager == context.executing_eagerly(), (
|
||||
"Iterator should be created and used in same execution mode.")
|
||||
|
||||
init_ops = []
|
||||
for it in self._iterators:
|
||||
init_ops.extend(it.initialize())
|
||||
@ -1842,7 +1839,7 @@ class _SingleWorkerDatasetIterator(object):
|
||||
"""Create iterator for the `dataset` to fetch data to worker's `devices` .
|
||||
|
||||
`MultiDeviceIterator` is used to prefetch input to the devices on the
|
||||
given worker. `MultiDeviceIterator` doesn't work in eager mode yet.
|
||||
given worker.
|
||||
|
||||
Args:
|
||||
dataset: A `tf.data.Dataset` instance.
|
||||
@ -1852,39 +1849,19 @@ class _SingleWorkerDatasetIterator(object):
|
||||
self._dataset = dataset
|
||||
self._worker = worker
|
||||
self._devices = devices
|
||||
self._is_eager = context.executing_eagerly()
|
||||
self._make_iterator()
|
||||
|
||||
def _make_iterator(self):
|
||||
"""Make appropriate iterator on the dataset."""
|
||||
with ops.device(self._worker):
|
||||
if self._is_eager:
|
||||
# TODO(rohanj): Enable prefetching in eager mode.
|
||||
# TODO(priyag): Measure the performance of this approach vs calling
|
||||
# get_next on the original dataset N times.
|
||||
dataset = self._dataset.batch(len(self._devices), drop_remainder=True)
|
||||
iterator = dataset_ops.make_one_shot_iterator(dataset)
|
||||
else:
|
||||
iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
|
||||
self._dataset, self._devices)
|
||||
self._iterator = iterator
|
||||
|
||||
def get_next_as_list(self, name=None):
|
||||
"""Get next element from the underlying iterator."""
|
||||
del name
|
||||
with ops.device(self._worker):
|
||||
if self._is_eager:
|
||||
# Batched dataset case.
|
||||
batch = self._iterator.get_next(name=name)
|
||||
data_list = []
|
||||
for i, d in enumerate(self._devices):
|
||||
v = nest.map_structure(operator.itemgetter(i), batch)
|
||||
with ops.device(d):
|
||||
v = nest.map_structure(array_ops.identity, v)
|
||||
data_list.append(v)
|
||||
else:
|
||||
# MultiDeviceIterator case.
|
||||
data_list = self._iterator.get_next()
|
||||
|
||||
return data_list
|
||||
|
||||
def initialize(self):
|
||||
@ -1897,7 +1874,7 @@ class _SingleWorkerDatasetIterator(object):
|
||||
Returns:
|
||||
A list of any initializer ops that should be run.
|
||||
"""
|
||||
if self._is_eager:
|
||||
if context.executing_eagerly():
|
||||
self._make_iterator()
|
||||
return []
|
||||
else:
|
||||
|
Loading…
Reference in New Issue
Block a user