Enabling prefetch_on_device in eager mode for distribution strategies

PiperOrigin-RevId: 226529748
Authored by Rohan Jain on 2018-12-21 12:54:25 -08:00; committed by TensorFlower Gardener
parent 181da18675
commit 08a692fe30
3 changed files with 52 additions and 53 deletions
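
In user-facing terms, this change makes `MultiDeviceIterator` constructible under eager execution, so distribution strategies can prefetch input to devices in both graph and eager modes. A minimal sketch of the newly supported pattern, using the same modules and arguments as the eager test added below; the `/cpu:1` and `/cpu:2` device strings assume an environment configured with multiple CPU devices, as the test environment is:

    # Sketch only: mirrors the test added in this commit.
    from tensorflow.python.data.ops import dataset_ops
    from tensorflow.python.data.ops import multi_device_iterator_ops

    dataset = dataset_ops.Dataset.range(10)
    # Before this commit, constructing this under eager execution was
    # not supported.
    multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
        dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
    elem_on_1, elem_on_2 = multi_device_iterator.get_next()  # 0 and 1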


@@ -188,7 +188,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
     self.evaluate(elem_on_2)
 
   @test_util.run_v1_only
-  def testMultipleInitializations(self):
+  def testMultipleInitializationsGraph(self):
     if context.executing_eagerly():
       return
@@ -209,6 +209,22 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase):
       self.assertEqual([(i, 0), (i, 1)], self.evaluate([elem_on_1,
                                                         elem_on_2]))
 
+  @test_util.run_v1_only
+  def testMultipleInitializationsEager(self):
+    if not context.executing_eagerly():
+      return
+
+    with ops.device("/cpu:0"):
+      dataset1 = dataset_ops.Dataset.range(1000)
+      dataset2 = dataset_ops.Dataset.range(1000)
+      dataset = dataset_ops.Dataset.zip((dataset1, dataset2))
+
+    for _ in range(1000):
+      multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+          dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4)
+      elem_on_1, elem_on_2 = multi_device_iterator.get_next()
+      self.assertEqual([(0, 0), (1, 1)], self.evaluate([elem_on_1, elem_on_2]))
+
   @test_util.run_v1_only
   def testBasicGpu(self):
     if not test_util.is_gpu_available():


@@ -160,10 +160,16 @@ class MultiDeviceIterator(object):
     # Create the MultiDeviceIterator.
     with ops.device(self._source_device):
+      # TODO(b/121378567): Get rid of this shared_name hack.
+      shared_name = ""
+      if context.executing_eagerly():
+        # Ensure a unique name when eager execution is enabled to avoid spurious
+        # sharing issues.
+        shared_name += str(ops.uid())
       self._multi_device_iterator_resource = (
           gen_dataset_ops.multi_device_iterator(
               devices=self._devices,
-              shared_name="",
+              shared_name=shared_name,
               container="",
               **dataset_ops.flat_structure(dataset)))
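
The `shared_name` change matters because the resource op looks iterators up by name: under eager execution, repeatedly constructing iterators with the same empty `shared_name` would alias the same underlying resource, which is what the new `testMultipleInitializationsEager` loop exercises. A conceptual, pure-Python sketch of that failure mode and the uid-based fix; the `_registry` dict is a hypothetical stand-in for the runtime's shared-resource table, not TF code:

    import itertools

    _registry = {}          # stand-in for the runtime's shared-resource table
    _uid = itertools.count()

    def lookup_or_create(shared_name):
        # A fixed name returns the same iterator state on every call.
        if shared_name not in _registry:
            _registry[shared_name] = {"position": 0}  # fresh iterator state
        return _registry[shared_name]

    a = lookup_or_create("")               # old behavior: everyone shares ""
    b = lookup_or_create("")
    assert a is b                          # spurious sharing between iterators

    c = lookup_or_create(str(next(_uid)))  # new behavior: unique name per call
    d = lookup_or_create(str(next(_uid)))
    assert c is not d                      # each iterator owns its own state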


@@ -20,7 +20,6 @@ from __future__ import print_function
 
 import collections
 import contextlib
-import operator
 import weakref
 
 import six
@@ -331,7 +330,10 @@ class DistributedDelegate(DistributedValues):
   def __rmul__(self, o): return o * self.get()
   def __truediv__(self, o): return self.get() / o
   def __rtruediv__(self, o): return o / self.get()
-  def __floordiv__(self, o): return self.get() // o
+
+  def __floordiv__(self, o):
+    return self.get() // o
+
   def __rfloordiv__(self, o): return o // self.get()
   def __mod__(self, o): return self.get() % o
   def __rmod__(self, o): return o % self.get()
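
For reference, the surrounding operators all follow the same delegation pattern: each arithmetic dunder forwards to the value returned by `get()`. A standalone sketch of that pattern; the `Delegate` class here is a hypothetical stand-in, not the TF implementation:

    class Delegate(object):
        def __init__(self, value):
            self._value = value

        def get(self):
            return self._value

        def __floordiv__(self, o):
            return self.get() // o

        def __rfloordiv__(self, o):
            # Reflected form: invoked for `o // delegate` when `o` does not
            # know how to divide by a Delegate.
            return o // self.get()

    print(Delegate(7) // 2)   # 3
    print(20 // Delegate(6))  # 3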
@@ -1492,21 +1494,22 @@ class PerReplicaDataset(object):
     self._input_workers = input_workers
     self._worker_index = worker_index
 
-    # Default to using prefetching in graph mode, unless specified.
-    # TODO(rohanj): Enable prefetching in eager mode.
+    # Default to using prefetching, unless specified.
     self._prefetch_on_device = prefetch_on_device
     if self._prefetch_on_device is None:
-      self._prefetch_on_device = not context.executing_eagerly()
-    assert not (self._prefetch_on_device and context.executing_eagerly()), (
-        "Prefetching is only supported in graph mode currently")
+      self._prefetch_on_device = True
 
     self._dataset = dataset
     if not self._prefetch_on_device:
       # TODO(priyag): If dropping remainder is not appropriate, find another
       # approach to distributing the dataset when not possible to divide evenly.
       # Possibly not an issue when we start using PartitionedDataset.
-      num_replicas = len(input_workers.compute_devices_for_worker(worker_index))
-      self._dataset = dataset.batch(num_replicas, drop_remainder=True)
+      num_replicas = len(
+          self._input_workers.compute_devices_for_worker(self._worker_index))
+      self._dataset = self._dataset.batch(num_replicas, drop_remainder=True)
+    else:
+      self._replica_devices = self._input_workers.compute_devices_for_worker(
+          self._worker_index)
 
   def make_one_shot_iterator(self):
     """Get a one time use iterator for the distributed PerReplicaDataset."""
@@ -1514,13 +1517,16 @@ class PerReplicaDataset(object):
     if not context.executing_eagerly():
       raise ValueError("Cannot create a one shot iterator. Please use "
                        "`make_initializable_iterator()` instead.")
-    # Eager mode prefetching would error out in constructor. Only remaining
-    # case is non-prefetching in eager mode. We delegate to
-    # PerReplicaDataIterator to handle that case.
-    dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
+    if self._prefetch_on_device:
+      dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
+          self._dataset, self._replica_devices)
+    else:
+      dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset)
     return PerReplicaDataIterator(
-        dataset_iterator, self._input_workers, self._worker_index,
-        prefetch_on_device=False)
+        dataset_iterator,
+        self._input_workers,
+        self._worker_index,
+        prefetch_on_device=self._prefetch_on_device)
 
   def make_initializable_iterator(self):
     """Get an initializable iterator for the distributed PerReplicaDataset."""
@@ -1530,10 +1536,8 @@ class PerReplicaDataset(object):
       raise ValueError("Cannot create initializable iterator in Eager mode. "
                        "Please use `make_one_shot_iterator` instead.")
     if self._prefetch_on_device:
-      replica_devices = self._input_workers.compute_devices_for_worker(
-          self._worker_index)
       dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator(
-          self._dataset, replica_devices)
+          self._dataset, self._replica_devices)
     else:
       dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset)
     return PerReplicaDataIterator(
@@ -1690,13 +1694,9 @@ class InputIteratorImpl(InputIterator):
     self._iterators = iterators
     self._input_workers = input_workers
-    self._is_eager = context.executing_eagerly()
 
   def get_next(self, name=None):
     """Returns the next input from the iterator for all replicas."""
-    assert self._is_eager == context.executing_eagerly(), (
-        "Iterator should be created and used in same execution mode.")
     replicas = []
     for i, worker in enumerate(self._input_workers.worker_devices):
       if name is not None:
@@ -1716,9 +1716,6 @@ class InputIteratorImpl(InputIterator):
     Returns:
       A list of any initializer ops that should be run.
     """
-    assert self._is_eager == context.executing_eagerly(), (
-        "Iterator should be created and used in same execution mode.")
-
     init_ops = []
     for it in self._iterators:
       init_ops.extend(it.initialize())
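
`initialize` simply flattens the per-worker initializers into one list for the caller to run. A minimal runnable sketch of that aggregation; `FakeIterator` is a hypothetical stand-in for a per-worker iterator:

    class FakeIterator(object):
        """Hypothetical stand-in for a per-worker iterator."""
        def __init__(self, name):
            self._name = name

        def initialize(self):
            return ["init_op_for_" + self._name]

    iterators = [FakeIterator("worker0"), FakeIterator("worker1")]
    init_ops = []
    for it in iterators:
        init_ops.extend(it.initialize())
    print(init_ops)  # ['init_op_for_worker0', 'init_op_for_worker1']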
@@ -1842,7 +1839,7 @@ class _SingleWorkerDatasetIterator(object):
     """Create iterator for the `dataset` to fetch data to worker's `devices` .
 
     `MultiDeviceIterator` is used to prefetch input to the devices on the
-    given worker. `MultiDeviceIterator` doesn't work in eager mode yet.
+    given worker.
 
     Args:
       dataset: A `tf.data.Dataset` instance.
@@ -1852,39 +1849,19 @@ class _SingleWorkerDatasetIterator(object):
     self._dataset = dataset
     self._worker = worker
     self._devices = devices
-    self._is_eager = context.executing_eagerly()
     self._make_iterator()
 
   def _make_iterator(self):
     """Make appropriate iterator on the dataset."""
     with ops.device(self._worker):
-      if self._is_eager:
-        # TODO(rohanj): Enable prefetching in eager mode.
-        # TODO(priyag): Measure the performance of this approach vs calling
-        # get_next on the original dataset N times.
-        dataset = self._dataset.batch(len(self._devices), drop_remainder=True)
-        iterator = dataset_ops.make_one_shot_iterator(dataset)
-      else:
-        iterator = multi_device_iterator_ops.MultiDeviceIterator(
-            self._dataset, self._devices)
-      self._iterator = iterator
+      self._iterator = multi_device_iterator_ops.MultiDeviceIterator(
+          self._dataset, self._devices)
 
   def get_next_as_list(self, name=None):
     """Get next element from the underlying iterator."""
+    del name
     with ops.device(self._worker):
-      if self._is_eager:
-        # Batched dataset case.
-        batch = self._iterator.get_next(name=name)
-        data_list = []
-        for i, d in enumerate(self._devices):
-          v = nest.map_structure(operator.itemgetter(i), batch)
-          with ops.device(d):
-            v = nest.map_structure(array_ops.identity, v)
-          data_list.append(v)
-      else:
-        # MultiDeviceIterator case.
-        data_list = self._iterator.get_next()
+      data_list = self._iterator.get_next()
       return data_list
 
   def initialize(self):
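
With the eager batching path gone, `get_next_as_list` relies on `MultiDeviceIterator` semantics: one host-side iterator feeds the devices in order, and each `get_next()` returns one element per device. A conceptual pure-Python sketch of that contract, not the TF implementation:

    def multi_device_get_next(source_iter, num_devices):
        # One element per device, taken in device order from the single
        # underlying stream.
        return [next(source_iter) for _ in range(num_devices)]

    source = iter(range(6))
    print(multi_device_get_next(source, 2))  # [0, 1]
    print(multi_device_get_next(source, 2))  # [2, 3]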
@@ -1897,7 +1874,7 @@ class _SingleWorkerDatasetIterator(object):
     Returns:
       A list of any initializer ops that should be run.
     """
-    if self._is_eager:
+    if context.executing_eagerly():
       self._make_iterator()
       return []
     else:
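
Finally, eager-mode `initialize` above re-initializes by rebuilding the iterator object rather than returning an op to run. A standalone sketch of that strategy, with a hypothetical class and plain Python iterators standing in for the dataset machinery:

    class SingleWorkerIterator(object):
        """Hypothetical stand-in illustrating rebuild-on-initialize."""
        def __init__(self, dataset):
            self._dataset = dataset
            self._make_iterator()

        def _make_iterator(self):
            self._iterator = iter(self._dataset)

        def initialize(self):
            # Eager mode: re-create the iterator; no ops for the caller to run.
            self._make_iterator()
            return []

        def get_next(self):
            return next(self._iterator)

    it = SingleWorkerIterator([1, 2, 3])
    print(it.get_next())  # 1
    it.initialize()       # rewind by rebuilding the underlying iterator
    print(it.get_next())  # 1 again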