diff --git a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py index 66117cf5b9d..2ae6a9893df 100644 --- a/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py +++ b/tensorflow/python/data/kernel_tests/multi_device_iterator_test.py @@ -188,7 +188,7 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase): self.evaluate(elem_on_2) @test_util.run_v1_only - def testMultipleInitializations(self): + def testMultipleInitializationsGraph(self): if context.executing_eagerly(): return @@ -209,6 +209,22 @@ class MultiDeviceIteratorTest(test_base.DatasetTestBase): self.assertEqual([(i, 0), (i, 1)], self.evaluate([elem_on_1, elem_on_2])) + @test_util.run_v1_only + def testMultipleInitializationsEager(self): + if not context.executing_eagerly(): + return + + with ops.device("/cpu:0"): + dataset1 = dataset_ops.Dataset.range(1000) + dataset2 = dataset_ops.Dataset.range(1000) + dataset = dataset_ops.Dataset.zip((dataset1, dataset2)) + + for _ in range(1000): + multi_device_iterator = multi_device_iterator_ops.MultiDeviceIterator( + dataset, ["/cpu:1", "/cpu:2"], prefetch_buffer_size=4) + elem_on_1, elem_on_2 = multi_device_iterator.get_next() + self.assertEqual([(0, 0), (1, 1)], self.evaluate([elem_on_1, elem_on_2])) + @test_util.run_v1_only def testBasicGpu(self): if not test_util.is_gpu_available(): diff --git a/tensorflow/python/data/ops/multi_device_iterator_ops.py b/tensorflow/python/data/ops/multi_device_iterator_ops.py index 2682e4acd0d..44aebb60cf5 100644 --- a/tensorflow/python/data/ops/multi_device_iterator_ops.py +++ b/tensorflow/python/data/ops/multi_device_iterator_ops.py @@ -160,10 +160,16 @@ class MultiDeviceIterator(object): # Create the MultiDeviceIterator. with ops.device(self._source_device): + # TODO(b/121378567): Get rid of this shared_name hack. + shared_name = "" + if context.executing_eagerly(): + # Ensure a unique name when eager execution is enabled to avoid spurious + # sharing issues. + shared_name += str(ops.uid()) self._multi_device_iterator_resource = ( gen_dataset_ops.multi_device_iterator( devices=self._devices, - shared_name="", + shared_name=shared_name, container="", **dataset_ops.flat_structure(dataset))) diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 1c84bc07b6c..abcc3b670da 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -20,7 +20,6 @@ from __future__ import print_function import collections import contextlib -import operator import weakref import six @@ -331,7 +330,10 @@ class DistributedDelegate(DistributedValues): def __rmul__(self, o): return o * self.get() def __truediv__(self, o): return self.get() / o def __rtruediv__(self, o): return o / self.get() - def __floordiv__(self, o): return self.get() // o + + def __floordiv__(self, o): + return self.get() // o + def __rfloordiv__(self, o): return o // self.get() def __mod__(self, o): return self.get() % o def __rmod__(self, o): return o % self.get() @@ -1492,21 +1494,22 @@ class PerReplicaDataset(object): self._input_workers = input_workers self._worker_index = worker_index - # Default to using prefetching in graph mode, unless specified. - # TODO(rohanj): Enable prefetching in eager mode. + # Default to using prefetching, unless specified. 
self._prefetch_on_device = prefetch_on_device if self._prefetch_on_device is None: - self._prefetch_on_device = not context.executing_eagerly() - assert not (self._prefetch_on_device and context.executing_eagerly()), ( - "Prefetching is only supported in graph mode currently") + self._prefetch_on_device = True self._dataset = dataset if not self._prefetch_on_device: # TODO(priyag): If dropping remainder is not appropriate, find another # approach to distributing the dataset when not possible to divide evenly. # Possibly not an issue when we start using PartitionedDataset. - num_replicas = len(input_workers.compute_devices_for_worker(worker_index)) - self._dataset = dataset.batch(num_replicas, drop_remainder=True) + num_replicas = len( + self._input_workers.compute_devices_for_worker(self._worker_index)) + self._dataset = self._dataset.batch(num_replicas, drop_remainder=True) + else: + self._replica_devices = self._input_workers.compute_devices_for_worker( + self._worker_index) def make_one_shot_iterator(self): """Get a one time use iterator for the distributed PerReplicaDataset.""" @@ -1514,13 +1517,16 @@ class PerReplicaDataset(object): if not context.executing_eagerly(): raise ValueError("Cannot create a one shot iterator. Please use " "`make_initializable_iterator()` instead.") - # Eager mode prefetching would error out in constructor. Only remaining - # case is non-prefetching in eager mode. We delegate to - # PerReplicaDataIterator to handle that case. - dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset) + if self._prefetch_on_device: + dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( + self._dataset, self._replica_devices) + else: + dataset_iterator = dataset_ops.make_one_shot_iterator(self._dataset) return PerReplicaDataIterator( - dataset_iterator, self._input_workers, self._worker_index, - prefetch_on_device=False) + dataset_iterator, + self._input_workers, + self._worker_index, + prefetch_on_device=self._prefetch_on_device) def make_initializable_iterator(self): """Get an initializable iterator for the distributed PerReplicaDataset.""" @@ -1530,10 +1536,8 @@ class PerReplicaDataset(object): raise ValueError("Cannot create initializable iterator in Eager mode. " "Please use `make_one_shot_iterator` instead.") if self._prefetch_on_device: - replica_devices = self._input_workers.compute_devices_for_worker( - self._worker_index) dataset_iterator = multi_device_iterator_ops.MultiDeviceIterator( - self._dataset, replica_devices) + self._dataset, self._replica_devices) else: dataset_iterator = dataset_ops.make_initializable_iterator(self._dataset) return PerReplicaDataIterator( @@ -1690,13 +1694,9 @@ class InputIteratorImpl(InputIterator): self._iterators = iterators self._input_workers = input_workers - self._is_eager = context.executing_eagerly() def get_next(self, name=None): """Returns the next input from the iterator for all replicas.""" - assert self._is_eager == context.executing_eagerly(), ( - "Iterator should be created and used in same execution mode.") - replicas = [] for i, worker in enumerate(self._input_workers.worker_devices): if name is not None: @@ -1716,9 +1716,6 @@ class InputIteratorImpl(InputIterator): Returns: A list of any initializer ops that should be run. 
""" - assert self._is_eager == context.executing_eagerly(), ( - "Iterator should be created and used in same execution mode.") - init_ops = [] for it in self._iterators: init_ops.extend(it.initialize()) @@ -1842,7 +1839,7 @@ class _SingleWorkerDatasetIterator(object): """Create iterator for the `dataset` to fetch data to worker's `devices` . `MultiDeviceIterator` is used to prefetch input to the devices on the - given worker. `MultiDeviceIterator` doesn't work in eager mode yet. + given worker. Args: dataset: A `tf.data.Dataset` instance. @@ -1852,39 +1849,19 @@ class _SingleWorkerDatasetIterator(object): self._dataset = dataset self._worker = worker self._devices = devices - self._is_eager = context.executing_eagerly() self._make_iterator() def _make_iterator(self): """Make appropriate iterator on the dataset.""" with ops.device(self._worker): - if self._is_eager: - # TODO(rohanj): Enable prefetching in eager mode. - # TODO(priyag): Measure the performance of this approach vs calling - # get_next on the original dataset N times. - dataset = self._dataset.batch(len(self._devices), drop_remainder=True) - iterator = dataset_ops.make_one_shot_iterator(dataset) - else: - iterator = multi_device_iterator_ops.MultiDeviceIterator( - self._dataset, self._devices) - self._iterator = iterator + self._iterator = multi_device_iterator_ops.MultiDeviceIterator( + self._dataset, self._devices) def get_next_as_list(self, name=None): """Get next element from the underlying iterator.""" + del name with ops.device(self._worker): - if self._is_eager: - # Batched dataset case. - batch = self._iterator.get_next(name=name) - data_list = [] - for i, d in enumerate(self._devices): - v = nest.map_structure(operator.itemgetter(i), batch) - with ops.device(d): - v = nest.map_structure(array_ops.identity, v) - data_list.append(v) - else: - # MultiDeviceIterator case. - data_list = self._iterator.get_next() - + data_list = self._iterator.get_next() return data_list def initialize(self): @@ -1897,7 +1874,7 @@ class _SingleWorkerDatasetIterator(object): Returns: A list of any initializer ops that should be run. """ - if self._is_eager: + if context.executing_eagerly(): self._make_iterator() return [] else: