Enabling prefetch_on_device in eager mode for distribution strategies
PiperOrigin-RevId: 226529748
commit 08a692fe30
parent 181da18675
@@ -188,7 +188,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
       self.evaluate(elem_on_2)

   @test_util.run_v1_only
-  def testMultipleInitializations(self):
+  def testMultipleInitializationsGraph(self):
     if context.executing_eagerly():
       return

@@ -209,6 +209,22 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
       self.assertEqual([(i, 0), (i, 1)], self.evaluate([elem_on_1,
                                                         elem_on_2]))

+  @test_util.run_v1_only
+  def testMultipleInitializationsEager(self):
+    if not context.executing_eagerly():
+      return
+
+    with ops.device("/cpu:0"):
+      dataset1 = dataset_ops.Dataset.range(1000)
+      dataset2 = dataset_ops.Dataset.range(1000)
+      dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
+
+    for _ in range(1000):
+      multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+          dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
+      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+      self.assertEqual([(0, 0), (1, 1)], self.evaluate([elem_on_1, elem_on_2]))
+
   @test_util.run_v1_only
   def testBasicGpu(self):
     if not test_util.is_gpu_available():
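Note: the new testMultipleInitializationsEager above is the behavioral core of this change: before the shared_name fix below, re-creating a MultiDeviceIterator in eager mode could collide with the resource of a previously created one. A minimal standalone sketch of the same pattern, assuming a TF 1.x build of this vintage and a config that exposes three logical CPU devices (the test harness provides "/cpu:1" and "/cpu:2" the same way):

    import tensorflow as tf
    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.data.ops import multi_device_iterator_ops

    # Three logical CPUs so that "/cpu:1" and "/cpu:2" exist.
    tf.enable_eager_execution(config=tf.ConfigProto(device_count={"CPU": 3}))

    dataset = dataset_ops.Dataset.range(10)
    for _ in range(2):
      # Each construction now gets a unique shared_name, so repeated
      # initializations no longer alias the same iterator resource.
      mdi = multi_device_iterator_ops.MultiDeviceIterator(
          dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
      elem_on_1, elem_on_2 = mdi.get_next()
      print(elem_on_1.numpy(), elem_on_2.numpy())  # 0 1 on each pass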
@@ -160,10 +160,16 @@ class MultiDeviceIterator(object):

     # Create the MultiDeviceIterator.
     with ops.device(self._source_device):
+      # TODO(b/121378567): Get rid of this shared_name hack.
+      shared_name = ""
+      if context.executing_eagerly():
+        # Ensure a unique name when eager execution is enabled to avoid spurious
+        # sharing issues.
+        shared_name += str(ops.uid())
       self._multi_device_iterator_resource = (
           gen_dataset_ops.multi_device_iterator(
               devices=self._devices,
-              shared_name="",
+              shared_name=shared_name,
               container="",
               **dataset_ops.flat_structure(dataset)))

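Note: the shared_name hack above exists because the underlying multi_device_iterator op locates its resource by shared_name, and under eager execution every construction would otherwise reuse the same empty name, producing the "spurious sharing" the comment mentions. ops.uid() is a process-wide monotonically increasing counter, so each eager instance gets a distinct name. A sketch of just the naming logic; the helper name is illustrative:

    from tensorflow.python.eager import context
    from tensorflow.python.framework import ops

    def _unique_shared_name():
      # Mirrors the logic in the hunk above: an empty (shared) name in graph
      # mode, a fresh integer suffix per construction in eager mode.
      shared_name = ""
      if context.executing_eagerly():
        shared_name += str(ops.uid())  # e.g. "1", "2", ...
      return shared_name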
@@ -20,7 +20,6 @@ from __future__ import print_function

 import collections
 import contextlib
-import operator
 import weakref
 import six

@@ -331,7 +330,10 @@ class DistributedDelegate(DistributedValues):
   def __rmul__(self, o): return o * self.get()
   def __truediv__(self, o): return self.get() / o
   def __rtruediv__(self, o): return o / self.get()
-  def __floordiv__(self, o): return self.get() // o
+
+  def __floordiv__(self, o):
+    return self.get() // o
+
   def __rfloordiv__(self, o): return o // self.get()
   def __mod__(self, o): return self.get() % o
   def __rmod__(self, o): return o % self.get()
@@ -1492,21 +1494,22 @@ class PerReplicaDataset(object):
     self._input_workers = input_workers
     self._worker_index = worker_index

-    # Default to using prefetching in graph mode, unless specified.
-    # TODO(rohanj): Enable prefetching in eager mode.
+    # Default to using prefetching, unless specified.
     self._prefetch_on_device = prefetch_on_device
     if self._prefetch_on_device is None:
-      self._prefetch_on_device = not context.executing_eagerly()
-    assert not (self._prefetch_on_device and context.executing_eagerly()), (
-        "Prefetching is only supported in graph mode currently")
+      self._prefetch_on_device = True

     self._dataset = dataset
     if not self._prefetch_on_device:
       # TODO(priyag): If dropping remainder is not appropriate, find another
       # approach to distributing the dataset when not possible to divide evenly.
       # Possibly not an issue when we start using PartitionedDataset.
-      num_replicas = len(input_workers.compute_devices_for_worker(worker_index))
-      self._dataset = dataset.batch(num_replicas, drop_remainder=True)
+      num_replicas = len(
+          self._input_workers.compute_devices_for_worker(self._worker_index))
+      self._dataset = self._dataset.batch(num_replicas, drop_remainder=True)
+    else:
+      self._replica_devices = self._input_workers.compute_devices_for_worker(
+          self._worker_index)

   def make_one_shot_iterator(self):
     """Get a one time use iterator for the distributed PerReplicaDataset."""
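Note: with this hunk, prefetch_on_device defaults to True in both graph and eager mode, and the old constructor-time assert banning eager prefetching is gone. The non-prefetch fallback still batches by the replica count so one batch can later be split across replicas; a quick illustration of why drop_remainder=True matters there, using only public tf.data APIs:

    import tensorflow as tf

    # With 2 replicas, a 5-element dataset yields [0 1] and [2 3]; the
    # trailing element 4 is dropped, so no replica is ever left without
    # data partway through a step, and the batch dimension stays static.
    num_replicas = 2
    dataset = tf.data.Dataset.range(5).batch(num_replicas, drop_remainder=True)
    print(dataset.output_shapes)  # (2,)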
@@ -1514,13 +1517,16 @@ class PerReplicaDataset(object):
     if not context.executing_eagerly():
       raise ValueError("Cannot create a one shot iterator. Please use "
                        "`make_initializable_iterator()` instead.")
-    # Eager mode prefetching would error out in constructor. Only remaining
-    # case is non-prefetching in eager mode. We delegate to
-    # PerReplicaDataIterator to handle that case.
-    dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
+    if self._prefetch_on_device:
+      dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+          self._dataset, self._replica_devices)
+    else:
+      dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
     return PerReplicaDataIterator(
-        dataset_iterator, self._input_workers, self._worker_index,
-        prefetch_on_device=False)
+        dataset_iterator,
+        self._input_workers,
+        self._worker_index,
+        prefetch_on_device=self._prefetch_on_device)

   def make_initializable_iterator(self):
     """Get an initializable iterator for the distributed PerReplicaDataset."""
@@ -1530,10 +1536,8 @@ class PerReplicaDataset(object):
       raise ValueError("Cannot create initializable iterator in Eager mode. "
                        "Please use `make_one_shot_iterator` instead.")
     if self._prefetch_on_device:
-      replica_devices = self._input_workers.compute_devices_for_worker(
-          self._worker_index)
       dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
-          self._dataset, replica_devices)
+          self._dataset, self._replica_devices)
     else:
       dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset)
     return PerReplicaDataIterator(
@@ -1690,13 +1694,9 @@ class InputIteratorImpl(InputIterator):

     self._iterators = iterators
     self._input_workers = input_workers
-    self._is_eager = context.executing_eagerly()

   def get_next(self, name=None):
     """Returns the next input from the iterator for all replicas."""
-    assert self._is_eager == context.executing_eagerly(), (
-        "Iterator should be created and used in same execution mode.")
-
     replicas = []
     for i, worker in enumerate(self._input_workers.worker_devices):
       if name is not None:
@@ -1716,9 +1716,6 @@ class InputIteratorImpl(InputIterator):
     Returns:
       A list of any initializer ops that should be run.
     """
-    assert self._is_eager == context.executing_eagerly(), (
-        "Iterator should be created and used in same execution mode.")
-
     init_ops = []
     for it in self._iterators:
       init_ops.extend(it.initialize())
@@ -1842,7 +1839,7 @@ class _SingleWorkerDatasetIterator(object):
     """Create iterator for the `dataset` to fetch data to worker's `devices` .

     `MultiDeviceIterator` is used to prefetch input to the devices on the
-    given worker. `MultiDeviceIterator` doesn't work in eager mode yet.
+    given worker.

     Args:
       dataset: A `tf.data.Dataset` instance.
@@ -1852,39 +1849,19 @@ class _SingleWorkerDatasetIterator(object):
     self._dataset = dataset
     self._worker = worker
     self._devices = devices
-    self._is_eager = context.executing_eagerly()
     self._make_iterator()

   def _make_iterator(self):
     """Make appropriate iterator on the dataset."""
     with ops.device(self._worker):
-      if self._is_eager:
-        # TODO(rohanj): Enable prefetching in eager mode.
-        # TODO(priyag): Measure the performance of this approach vs calling
-        # get_next on the original dataset N times.
-        dataset = self._dataset.batch(len(self._devices), drop_remainder=True)
-        iterator = dataset_ops.make_one_shot_iterator(dataset)
-      else:
-        iterator = multi_device_iterator_ops.MultiDeviceIterator(
-            self._dataset, self._devices)
-      self._iterator = iterator
+      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
+          self._dataset, self._devices)

   def get_next_as_list(self, name=None):
     """Get next element from the underlying iterator."""
+    del name
     with ops.device(self._worker):
-      if self._is_eager:
-        # Batched dataset case.
-        batch = self._iterator.get_next(name=name)
-        data_list = []
-        for i, d in enumerate(self._devices):
-          v = nest.map_structure(operator.itemgetter(i), batch)
-          with ops.device(d):
-            v = nest.map_structure(array_ops.identity, v)
-          data_list.append(v)
-      else:
-        # MultiDeviceIterator case.
-        data_list = self._iterator.get_next()
+      data_list = self._iterator.get_next()

     return data_list

   def initialize(self):
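Note: this hunk deletes the eager-only workaround wholesale. Previously the eager path batched the dataset by len(devices) and split each batch across devices by index, because MultiDeviceIterator did not yet work eagerly; get_next_as_list also now discards its name argument (del name) since MultiDeviceIterator.get_next takes none. For reference, the deleted splitting logic repackaged as a self-contained function (the function name is illustrative):

    import operator
    from tensorflow.python.framework import ops
    from tensorflow.python.ops import array_ops
    from tensorflow.python.util import nest

    def split_batch_across_devices(batch, devices):
      """Peels replica i's slice off `batch` and places it on devices[i]."""
      data_list = []
      for i, d in enumerate(devices):
        v = nest.map_structure(operator.itemgetter(i), batch)
        with ops.device(d):
          v = nest.map_structure(array_ops.identity, v)  # device-placed copy
        data_list.append(v)
      return data_list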
@@ -1897,7 +1874,7 @@ class _SingleWorkerDatasetIterator(object):
     Returns:
       A list of any initializer ops that should be run.
     """
-    if self._is_eager:
+    if context.executing_eagerly():
       self._make_iterator()
       return []
     else:
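Note: with _is_eager removed, initialize() checks context.executing_eagerly() at call time instead of a mode frozen at construction. In eager mode there are no initializer ops to return, so re-initialization simply rebuilds the MultiDeviceIterator. A sketch of the whole method under that reading; the hunk truncates before the graph branch, so the returned initializer below is an assumption:

    def initialize(self):
      if context.executing_eagerly():
        self._make_iterator()  # fresh iterator resource; nothing to run
        return []
      else:
        # Assumed from MultiDeviceIterator's graph-mode API; the diff above
        # is cut off before this line.
        return [self._iterator.initializer]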