From 0c9b20a3a41f9e0a188e3e0c6a7d147de066daed Mon Sep 17 00:00:00 2001 From: Jonathan DEKHTIAR Date: Tue, 3 Mar 2020 17:42:30 -0800 Subject: [PATCH 01/19] Allow Prefetching on GPU --- tensorflow/python/data/ops/dataset_ops.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index c7b2257c510..f7c3a3c25ce 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -4266,11 +4266,14 @@ class PrefetchDataset(UnaryUnchangedStructureDataset): buffer_size = -1 # This is the sentinel for auto-tuning. self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") - variant_tensor = gen_dataset_ops.prefetch_dataset( - input_dataset._variant_tensor, # pylint: disable=protected-access - buffer_size=self._buffer_size, - slack_period=slack_period, - **self._flat_structure) + + with ops.device(input_dataset._variant_tensor.device): + variant_tensor = gen_dataset_ops.prefetch_dataset( + input_dataset._variant_tensor, # pylint: disable=protected-access + buffer_size=self._buffer_size, + slack_period=slack_period, + **self._flat_structure) + super(PrefetchDataset, self).__init__(input_dataset, variant_tensor) From 23f61e1353ec125ce00f7b3fe343c5ccaaf3bb41 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Wed, 4 Mar 2020 18:07:37 -0800 Subject: [PATCH 02/19] Unittest Added to reproduce the bug --- .../data/experimental/kernel_tests/prefetch_to_device_test.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py index 8ac4e239881..02ea45c8a07 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py @@ -148,6 +148,8 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): device_dataset = host_dataset.apply( prefetching_ops.prefetch_to_device("/gpu:0")) + self.assertEqual(device_dataset._variant_tensor.device, '/device:GPU:0') + iterator = dataset_ops.make_initializable_iterator(device_dataset) next_element = iterator.get_next() @@ -196,6 +198,8 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): device_dataset = host_dataset.apply( prefetching_ops.prefetch_to_device("/gpu:0")) + self.assertEqual(device_dataset._variant_tensor.device, '/device:GPU:0') + iterator = dataset_ops.make_initializable_iterator(device_dataset) next_element = iterator.get_next() From d8795eae445873e6abcb8ca18f21495894b7b86a Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Wed, 4 Mar 2020 18:38:47 -0800 Subject: [PATCH 03/19] Improved Unittest --- .../kernel_tests/prefetch_to_device_test.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py index 02ea45c8a07..83ad036d5a2 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py @@ -146,13 +146,14 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/gpu:0")) - - self.assertEqual(device_dataset._variant_tensor.device, '/device:GPU:0') + prefetching_ops.prefetch_to_device('/gpu:0')) iterator = dataset_ops.make_initializable_iterator(device_dataset) next_element = iterator.get_next() + self.assertEqual(device_dataset._variant_tensor.device, '/device:GPU:0') + self.assertEqual(next_element.device, '/device:GPU:0') + with self.cached_session( config=config_pb2.ConfigProto(allow_soft_placement=False)): self.evaluate(iterator.initializer) @@ -196,13 +197,14 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/gpu:0")) - - self.assertEqual(device_dataset._variant_tensor.device, '/device:GPU:0') + prefetching_ops.prefetch_to_device('/gpu:0')) iterator = dataset_ops.make_initializable_iterator(device_dataset) next_element = iterator.get_next() + self.assertEqual(device_dataset._variant_tensor.device, '/device:GPU:0') + self.assertEqual(next_element.device, '/device:GPU:0') + with self.cached_session( config=config_pb2.ConfigProto(allow_soft_placement=False)): self.evaluate(iterator.initializer) From 823befe5a183336d0fe5a68dec498bf83996af89 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Wed, 4 Mar 2020 18:39:14 -0800 Subject: [PATCH 04/19] iterator.get_next() colocation bug also fixed --- tensorflow/python/data/ops/dataset_ops.py | 2 +- tensorflow/python/data/ops/iterator_ops.py | 7 ++++--- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index f7c3a3c25ce..9990e866dc2 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -4267,7 +4267,7 @@ class PrefetchDataset(UnaryUnchangedStructureDataset): self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") - with ops.device(input_dataset._variant_tensor.device): + with ops.colocate_with(input_dataset._variant_tensor.device): variant_tensor = gen_dataset_ops.prefetch_dataset( input_dataset._variant_tensor, # pylint: disable=protected-access buffer_size=self._buffer_size, diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 668af74acf6..5c5c7c6d94c 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -419,13 +419,14 @@ class Iterator(trackable.Trackable): if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD: warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE) - # pylint: disable=protected-access - flat_ret = gen_dataset_ops.iterator_get_next( + with ops.colocate_with(self._iterator_resource): + # pylint: disable=protected-access + flat_ret = gen_dataset_ops.iterator_get_next( self._iterator_resource, output_types=self._flat_tensor_types, output_shapes=self._flat_tensor_shapes, name=name) - return structure.from_tensor_list(self._element_spec, flat_ret) + return structure.from_tensor_list(self._element_spec, flat_ret) def string_handle(self, name=None): """Returns a string-valued `tf.Tensor` that represents this iterator. From 1ae518c45885307fb7414fea9c0198f4fe379efd Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Fri, 6 Mar 2020 17:23:23 -0800 Subject: [PATCH 05/19] Additional GPU Prefetch Fix --- .../data/experimental/ops/prefetching_ops.py | 5 +- tensorflow/python/data/ops/dataset_ops.py | 93 ++++++++++--------- tensorflow/python/data/ops/iterator_ops.py | 3 +- 3 files changed, 54 insertions(+), 47 deletions(-) diff --git a/tensorflow/python/data/experimental/ops/prefetching_ops.py b/tensorflow/python/data/experimental/ops/prefetching_ops.py index 9401ebd7bf0..5228c094b2e 100644 --- a/tensorflow/python/data/experimental/ops/prefetching_ops.py +++ b/tensorflow/python/data/experimental/ops/prefetching_ops.py @@ -70,12 +70,9 @@ def copy_to_device(target_device, source_device="/cpu:0"): """ def _apply_fn(dataset): - options = dataset_ops.Options() - options.experimental_optimization.apply_default_optimizations = False - options.experimental_optimization.autotune = False return _CopyToDeviceDataset( dataset, target_device=target_device, - source_device=source_device).with_options(options) + source_device=source_device) return _apply_fn diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 9990e866dc2..d86fb1b673a 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -344,49 +344,55 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): def _apply_options(self): """Apply options, such as optimization configuration, to the dataset.""" + # TODO DEKHTIARJonathan: Remove when GPU OP exists + if "/device:GPU" in self._variant_tensor.device: + return self + dataset = self options = self.options() - # (1) Apply threading options - if options.experimental_threading is not None: - t_options = options.experimental_threading - if t_options.max_intra_op_parallelism is not None: - dataset = _MaxIntraOpParallelismDataset( - dataset, t_options.max_intra_op_parallelism) - if t_options.private_threadpool_size is not None: - dataset = _PrivateThreadPoolDataset(dataset, - t_options.private_threadpool_size) + with ops.colocate_with(dataset._variant_tensor): - # (2) Apply graph rewrite options - # pylint: disable=protected-access - graph_rewrites = options._graph_rewrites() - graph_rewrite_configs = options._graph_rewrite_configs() - # pylint: enable=protected-access - if graph_rewrites: - if self._has_captured_ref(): - warnings.warn( - "tf.data graph rewrites are not compatible with tf.Variable. " - "The following rewrites will be disabled: %s. To enable " - "rewrites, use resource variables instead by calling " - "`tf.enable_resource_variables()` at the start of the program." % - ", ".join(graph_rewrites)) - else: - dataset = _OptimizeDataset(dataset, graph_rewrites, - graph_rewrite_configs) + # (1) Apply threading options + if options.experimental_threading is not None: + t_options = options.experimental_threading + if t_options.max_intra_op_parallelism is not None: + dataset = _MaxIntraOpParallelismDataset( + dataset, t_options.max_intra_op_parallelism) + if t_options.private_threadpool_size is not None: + dataset = _PrivateThreadPoolDataset(dataset, + t_options.private_threadpool_size) - # (3) Apply autotune options - autotune, algorithm, cpu_budget = options._autotune_settings() # pylint: disable=protected-access + # (2) Apply graph rewrite options + # pylint: disable=protected-access + graph_rewrites = options._graph_rewrites() + graph_rewrite_configs = options._graph_rewrite_configs() + # pylint: enable=protected-access + if graph_rewrites: + if self._has_captured_ref(): + warnings.warn( + "tf.data graph rewrites are not compatible with tf.Variable. " + "The following rewrites will be disabled: %s. To enable " + "rewrites, use resource variables instead by calling " + "`tf.enable_resource_variables()` at the start of the program." % + ", ".join(graph_rewrites)) + else: + dataset = _OptimizeDataset(dataset, graph_rewrites, + graph_rewrite_configs) - if autotune: - dataset = _ModelDataset(dataset, algorithm, cpu_budget) + # (3) Apply autotune options + autotune, algorithm, cpu_budget = options._autotune_settings() # pylint: disable=protected-access - # (4) Apply stats aggregator options - if options.experimental_stats and options.experimental_stats.aggregator: # pylint: disable=line-too-long - dataset = _SetStatsAggregatorDataset( # pylint: disable=protected-access - dataset, options.experimental_stats.aggregator, - options.experimental_stats.prefix, - options.experimental_stats.counter_prefix) - return dataset + if autotune: + dataset = _ModelDataset(dataset, algorithm, cpu_budget) + + # (4) Apply stats aggregator options + if options.experimental_stats and options.experimental_stats.aggregator: # pylint: disable=line-too-long + dataset = _SetStatsAggregatorDataset( # pylint: disable=protected-access + dataset, options.experimental_stats.aggregator, + options.experimental_stats.prefix, + options.experimental_stats.counter_prefix) + return dataset def __iter__(self): """Creates an `Iterator` for enumerating the elements of this dataset. @@ -2194,14 +2200,17 @@ class DatasetV1(DatasetV2): dataset = self._apply_options() if shared_name is None: shared_name = "" - iterator_resource = gen_dataset_ops.iterator_v2( + + with ops.colocate_with(self._variant_tensor): + iterator_resource = gen_dataset_ops.iterator_v2( container="", shared_name=shared_name, **self._flat_structure) - with ops.colocate_with(iterator_resource): + initializer = gen_dataset_ops.make_iterator( dataset._variant_tensor, # pylint: disable=protected-access iterator_resource) - # pylint: disable=protected-access - return iterator_ops.Iterator( + + # pylint: disable=protected-access + return iterator_ops.Iterator( iterator_resource, initializer, get_legacy_output_types(dataset), get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset)) @@ -4266,8 +4275,8 @@ class PrefetchDataset(UnaryUnchangedStructureDataset): buffer_size = -1 # This is the sentinel for auto-tuning. self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") - - with ops.colocate_with(input_dataset._variant_tensor.device): + + with ops.colocate_with(input_dataset._variant_tensor): variant_tensor = gen_dataset_ops.prefetch_dataset( input_dataset._variant_tensor, # pylint: disable=protected-access buffer_size=self._buffer_size, diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 5c5c7c6d94c..975620f9ecb 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -367,7 +367,8 @@ class Iterator(trackable.Trackable): raise TypeError("Expected output shapes compatible with %r but got " "dataset with output shapes %r." % (self.output_shapes, dataset_output_shapes)) - with ops.colocate_with(self._iterator_resource): + + with ops.colocate_with(dataset._variant_tensor): return gen_dataset_ops.make_iterator( dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access From d5e92b8492adbeb204b9ae521ffda2b2237cffb9 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Fri, 6 Mar 2020 17:25:22 -0800 Subject: [PATCH 06/19] Formatting fix tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py --- .../data/experimental/kernel_tests/prefetch_to_device_test.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py index 83ad036d5a2..8c6b1a03227 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py @@ -146,7 +146,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device('/gpu:0')) + prefetching_ops.prefetch_to_device("/gpu:0")) iterator = dataset_ops.make_initializable_iterator(device_dataset) next_element = iterator.get_next() @@ -197,7 +197,7 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device('/gpu:0')) + prefetching_ops.prefetch_to_device("/gpu:0")) iterator = dataset_ops.make_initializable_iterator(device_dataset) next_element = iterator.get_next() From c7ddf19da1de1e503ebc1af156e5713cd15f955b Mon Sep 17 00:00:00 2001 From: Jonathan DEKHTIAR Date: Tue, 7 Apr 2020 12:02:51 -0700 Subject: [PATCH 07/19] Update prefetching_ops.py --- tensorflow/python/data/experimental/ops/prefetching_ops.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/python/data/experimental/ops/prefetching_ops.py b/tensorflow/python/data/experimental/ops/prefetching_ops.py index 5228c094b2e..373504306dd 100644 --- a/tensorflow/python/data/experimental/ops/prefetching_ops.py +++ b/tensorflow/python/data/experimental/ops/prefetching_ops.py @@ -70,8 +70,12 @@ def copy_to_device(target_device, source_device="/cpu:0"): """ def _apply_fn(dataset): + options = dataset_ops.Options() + options.experimental_optimization.apply_default_optimizations = False + options.experimental_optimization.autotune = False return _CopyToDeviceDataset( dataset, target_device=target_device, + source_device=source_device).with_options(options) source_device=source_device) return _apply_fn From a61c961c81b4a885f01f7fa56655facc411b1014 Mon Sep 17 00:00:00 2001 From: Jonathan DEKHTIAR Date: Tue, 7 Apr 2020 12:04:31 -0700 Subject: [PATCH 08/19] Update dataset_ops.py --- tensorflow/python/data/ops/dataset_ops.py | 78 +++++++++++------------ 1 file changed, 36 insertions(+), 42 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index d86fb1b673a..d2926ce018a 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -344,55 +344,49 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): def _apply_options(self): """Apply options, such as optimization configuration, to the dataset.""" - # TODO DEKHTIARJonathan: Remove when GPU OP exists - if "/device:GPU" in self._variant_tensor.device: - return self - dataset = self options = self.options() - with ops.colocate_with(dataset._variant_tensor): + # (1) Apply threading options + if options.experimental_threading is not None: + t_options = options.experimental_threading + if t_options.max_intra_op_parallelism is not None: + dataset = _MaxIntraOpParallelismDataset( + dataset, t_options.max_intra_op_parallelism) + if t_options.private_threadpool_size is not None: + dataset = _PrivateThreadPoolDataset(dataset, + t_options.private_threadpool_size) - # (1) Apply threading options - if options.experimental_threading is not None: - t_options = options.experimental_threading - if t_options.max_intra_op_parallelism is not None: - dataset = _MaxIntraOpParallelismDataset( - dataset, t_options.max_intra_op_parallelism) - if t_options.private_threadpool_size is not None: - dataset = _PrivateThreadPoolDataset(dataset, - t_options.private_threadpool_size) + # (2) Apply graph rewrite options + # pylint: disable=protected-access + graph_rewrites = options._graph_rewrites() + graph_rewrite_configs = options._graph_rewrite_configs() + # pylint: enable=protected-access + if graph_rewrites: + if self._has_captured_ref(): + warnings.warn( + "tf.data graph rewrites are not compatible with tf.Variable. " + "The following rewrites will be disabled: %s. To enable " + "rewrites, use resource variables instead by calling " + "`tf.enable_resource_variables()` at the start of the program." % + ", ".join(graph_rewrites)) + else: + dataset = _OptimizeDataset(dataset, graph_rewrites, + graph_rewrite_configs) - # (2) Apply graph rewrite options - # pylint: disable=protected-access - graph_rewrites = options._graph_rewrites() - graph_rewrite_configs = options._graph_rewrite_configs() - # pylint: enable=protected-access - if graph_rewrites: - if self._has_captured_ref(): - warnings.warn( - "tf.data graph rewrites are not compatible with tf.Variable. " - "The following rewrites will be disabled: %s. To enable " - "rewrites, use resource variables instead by calling " - "`tf.enable_resource_variables()` at the start of the program." % - ", ".join(graph_rewrites)) - else: - dataset = _OptimizeDataset(dataset, graph_rewrites, - graph_rewrite_configs) + # (3) Apply autotune options + autotune, algorithm, cpu_budget = options._autotune_settings() # pylint: disable=protected-access - # (3) Apply autotune options - autotune, algorithm, cpu_budget = options._autotune_settings() # pylint: disable=protected-access + if autotune: + dataset = _ModelDataset(dataset, algorithm, cpu_budget) - if autotune: - dataset = _ModelDataset(dataset, algorithm, cpu_budget) - - # (4) Apply stats aggregator options - if options.experimental_stats and options.experimental_stats.aggregator: # pylint: disable=line-too-long - dataset = _SetStatsAggregatorDataset( # pylint: disable=protected-access - dataset, options.experimental_stats.aggregator, - options.experimental_stats.prefix, - options.experimental_stats.counter_prefix) - return dataset + # (4) Apply stats aggregator options + if options.experimental_stats and options.experimental_stats.aggregator: # pylint: disable=line-too-long + dataset = _SetStatsAggregatorDataset( # pylint: disable=protected-access + dataset, options.experimental_stats.aggregator, + options.experimental_stats.prefix, + options.experimental_stats.counter_prefix) + return dataset def __iter__(self): """Creates an `Iterator` for enumerating the elements of this dataset. From b7552cff4e9bb4f9d0b5a9f80c8a607e8db82901 Mon Sep 17 00:00:00 2001 From: Jonathan DEKHTIAR Date: Tue, 7 Apr 2020 12:05:18 -0700 Subject: [PATCH 09/19] Update prefetching_ops.py --- tensorflow/python/data/experimental/ops/prefetching_ops.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tensorflow/python/data/experimental/ops/prefetching_ops.py b/tensorflow/python/data/experimental/ops/prefetching_ops.py index 373504306dd..9401ebd7bf0 100644 --- a/tensorflow/python/data/experimental/ops/prefetching_ops.py +++ b/tensorflow/python/data/experimental/ops/prefetching_ops.py @@ -76,7 +76,6 @@ def copy_to_device(target_device, source_device="/cpu:0"): return _CopyToDeviceDataset( dataset, target_device=target_device, source_device=source_device).with_options(options) - source_device=source_device) return _apply_fn From d9c7ef6adc1a6889c2130248bd9a74b08140d40e Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Tue, 7 Apr 2020 16:26:39 -0700 Subject: [PATCH 10/19] Change ops.colocate_with to ops.device --- tensorflow/python/data/ops/dataset_ops.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index d2926ce018a..0047c9263a0 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2195,7 +2195,7 @@ class DatasetV1(DatasetV2): if shared_name is None: shared_name = "" - with ops.colocate_with(self._variant_tensor): + with ops.device(self._variant_tensor.device): iterator_resource = gen_dataset_ops.iterator_v2( container="", shared_name=shared_name, **self._flat_structure) @@ -4270,7 +4270,7 @@ class PrefetchDataset(UnaryUnchangedStructureDataset): self._buffer_size = ops.convert_to_tensor( buffer_size, dtype=dtypes.int64, name="buffer_size") - with ops.colocate_with(input_dataset._variant_tensor): + with ops.device(input_dataset._variant_tensor.device): variant_tensor = gen_dataset_ops.prefetch_dataset( input_dataset._variant_tensor, # pylint: disable=protected-access buffer_size=self._buffer_size, From 6dcbc61bfb05c87aa9b184a2feaf189e39fbe58d Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Tue, 7 Apr 2020 16:41:16 -0700 Subject: [PATCH 11/19] Change ops.colocate_with to ops.device in Iterator --- tensorflow/python/data/ops/iterator_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 975620f9ecb..bd904d3c2ff 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -420,7 +420,7 @@ class Iterator(trackable.Trackable): if self._get_next_call_count > GET_NEXT_CALL_WARNING_THRESHOLD: warnings.warn(GET_NEXT_CALL_WARNING_MESSAGE) - with ops.colocate_with(self._iterator_resource): + with ops.device(self._iterator_resource.device): # pylint: disable=protected-access flat_ret = gen_dataset_ops.iterator_get_next( self._iterator_resource, From 6fdaa569ac5a45c610dbfb59720c4441581bdd2e Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Tue, 7 Apr 2020 16:43:47 -0700 Subject: [PATCH 12/19] Change ops.colocate_with to ops.device in Iterator - 2 --- tensorflow/python/data/ops/iterator_ops.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index bd904d3c2ff..d69f33a70d5 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -368,7 +368,7 @@ class Iterator(trackable.Trackable): "dataset with output shapes %r." % (self.output_shapes, dataset_output_shapes)) - with ops.colocate_with(dataset._variant_tensor): + with ops.device(dataset._variant_tensor.device): return gen_dataset_ops.make_iterator( dataset._variant_tensor, self._iterator_resource, name=name) # pylint: disable=protected-access From e78dc79da958a50dcafae90bc362830b86af31c7 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Wed, 8 Apr 2020 16:22:11 -0700 Subject: [PATCH 13/19] Unittest fix --- .../experimental/kernel_tests/prefetch_to_device_test.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py index 872a1effdeb..34f235bd963 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py @@ -148,8 +148,12 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): device_dataset = host_dataset.apply( prefetching_ops.prefetch_to_device("/gpu:0")) - self.assertEqual(host_dataset._variant_tensor.device, '/device:CPU:0') - self.assertEqual(device_dataset._variant_tensor.device, '/device:GPU:0') + self.assertTrue(( + "" == host_dataset._variant_tensor.device or + "CPU:0" in host_dataset._variant_tensor.device + )) + + self.assertTrue("GPU:0" in device_dataset._variant_tensor.device) self.assertDatasetProduces(device_dataset, list(range(10))) From 65b1a97ab98f09b197ee7b06a1ac6290349be816 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Wed, 8 Apr 2020 19:43:03 -0700 Subject: [PATCH 14/19] Unittests modifications as requested --- .../kernel_tests/prefetch_to_device_test.py | 17 ++-- tensorflow/python/data/kernel_tests/BUILD | 1 + .../python/data/kernel_tests/iterator_test.py | 81 +++++++++++++++++++ tensorflow/python/data/ops/dataset_ops.py | 24 +++--- 4 files changed, 109 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py index 34f235bd963..9f0ccefe4a1 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py @@ -148,6 +148,18 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): device_dataset = host_dataset.apply( prefetching_ops.prefetch_to_device("/gpu:0")) + self.assertDatasetProduces(device_dataset, list(range(10))) + + @combinations.generate(test_base.default_test_combinations()) + def testPrefetchToDeviceCorrectPlacement(self): + + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + self.assertTrue(( "" == host_dataset._variant_tensor.device or "CPU:0" in host_dataset._variant_tensor.device @@ -155,8 +167,6 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertTrue("GPU:0" in device_dataset._variant_tensor.device) - self.assertDatasetProduces(device_dataset, list(range(10))) - @combinations.generate(test_base.graph_only_combinations()) def testPrefetchToDeviceWithReInit(self): host_dataset = dataset_ops.Dataset.range(10) @@ -197,9 +207,6 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): iterator = dataset_ops.make_initializable_iterator(device_dataset) next_element = iterator.get_next() - self.assertEqual(device_dataset._variant_tensor.device, '/device:GPU:0') - self.assertEqual(next_element.device, '/device:GPU:0') - with self.cached_session( config=config_pb2.ConfigProto(allow_soft_placement=False)): self.evaluate(iterator.initializer) diff --git a/tensorflow/python/data/kernel_tests/BUILD b/tensorflow/python/data/kernel_tests/BUILD index eac7f4fd552..8f138b9ebc2 100644 --- a/tensorflow/python/data/kernel_tests/BUILD +++ b/tensorflow/python/data/kernel_tests/BUILD @@ -341,6 +341,7 @@ cuda_py_test( "//tensorflow/python:util", "//tensorflow/python:variables", "//tensorflow/python/compat", + "//tensorflow/python/data/experimental/ops:prefetching_ops", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:readers", diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index 36689ed75fb..3e914f8a373 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -25,6 +25,8 @@ import numpy as np from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session +from tensorflow.python import tf2 +from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import iterator_ops @@ -1016,6 +1018,85 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): self.evaluate(counter_var.initializer) self.assertEqual(self.evaluate(fn()), 10) + def assert_dataset_placement(self, host_dataset, host_iterator, host_tensor, + device_dataset, device_iterator, device_tensor): + + def assert_host_placement(_obj): + try: + self.assertIn("CPU:0", _obj) + except AssertionError: + self.assertEqual(_obj, "") + + assert_host_placement(host_dataset._variant_tensor.device) + assert_host_placement(host_tensor.device) + + self.assertIn("GPU:0", device_dataset._variant_tensor.device) + self.assertIn("GPU:0", device_tensor.device) + + if not tf2.enabled() or context.executing_eagerly(): + assert_host_placement(host_iterator._device) + self.assertIn("GPU:0", device_iterator._device) + + @combinations.generate(test_base.eager_only_combinations()) + def testIteratorOnDeviceEagerMode(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + host_iterator = iter(host_dataset) + device_iterator = iter(device_dataset) + + host_tensor = next(host_iterator) + device_tensor = next(device_iterator) + + self.assert_dataset_placement( + host_dataset, host_iterator, host_tensor, + device_dataset, device_iterator, device_tensor + ) + + @combinations.generate(test_base.graph_only_combinations()) + def testIteratorOnDeviceGraphModeOneShotIterator(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + host_iterator = dataset_ops.make_one_shot_iterator(host_dataset) + device_iterator = dataset_ops.make_one_shot_iterator(device_dataset) + + host_tensor = host_iterator.get_next() + device_tensor = device_iterator.get_next() + + self.assert_dataset_placement( + host_dataset, host_iterator, host_tensor, + device_dataset, device_iterator, device_tensor + ) + + @combinations.generate(test_base.graph_only_combinations()) + def testIteratorOnDeviceGraphModeInitializableIterator(self): + if not test_util.is_gpu_available(): + self.skipTest("No GPU available") + + host_dataset = dataset_ops.Dataset.range(10) + device_dataset = host_dataset.apply( + prefetching_ops.prefetch_to_device("/gpu:0")) + + host_iterator = dataset_ops.make_initializable_iterator(host_dataset) + device_iterator = dataset_ops.make_initializable_iterator(device_dataset) + + host_tensor = host_iterator.get_next() + device_tensor = device_iterator.get_next() + + self.assert_dataset_placement( + host_dataset, host_iterator, host_tensor, + device_dataset, device_iterator, device_tensor + ) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 8d4e02717c6..40411e6a7a7 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -399,7 +399,8 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): RuntimeError: If not inside of tf.function and not executing eagerly. """ if context.executing_eagerly() or ops.inside_function(): - return iterator_ops.OwnedIterator(self) + with ops.device(self._variant_tensor.device): + return iterator_ops.OwnedIterator(self) else: raise RuntimeError("__iter__() is only supported inside of tf.function " "or when eager execution is enabled.") @@ -474,13 +475,15 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor): if not context.executing_eagerly(): raise RuntimeError("as_numpy_iterator() is not supported while tracing " "functions") + for component_spec in nest.flatten(self.element_spec): if not isinstance(component_spec, tensor_spec.TensorSpec): raise TypeError( "Dataset.as_numpy_iterator() does not support datasets containing " + str(component_spec.value_type)) - return _NumpyIterator(self) + with ops.device(self._variant_tensor.device): + return _NumpyIterator(self) @property def _flat_shapes(self): @@ -2144,8 +2147,10 @@ class DatasetV1(DatasetV2): return self._make_one_shot_iterator() def _make_one_shot_iterator(self): # pylint: disable=missing-docstring + if context.executing_eagerly(): - return iterator_ops.OwnedIterator(self) + with ops.device(self._variant_tensor.device): + return iterator_ops.OwnedIterator(self) _ensure_same_dataset_graph(self) # Now that we create datasets at python object creation time, the capture @@ -2187,12 +2192,13 @@ class DatasetV1(DatasetV2): else: six.reraise(ValueError, err) - # pylint: disable=protected-access - return iterator_ops.Iterator( - gen_dataset_ops.one_shot_iterator( - dataset_factory=_make_dataset, **self._flat_structure), None, - get_legacy_output_types(self), get_legacy_output_shapes(self), - get_legacy_output_classes(self)) + with ops.device(self._variant_tensor.device): + # pylint: disable=protected-access + return iterator_ops.Iterator( + gen_dataset_ops.one_shot_iterator( + dataset_factory=_make_dataset, **self._flat_structure), None, + get_legacy_output_types(self), get_legacy_output_shapes(self), + get_legacy_output_classes(self)) @deprecation.deprecated( None, "This is a deprecated API that should only be used in TF 1 graph " From 5f4fa6a193fe9a755841c82ec3f4654dc805f884 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Thu, 9 Apr 2020 11:10:58 -0700 Subject: [PATCH 15/19] Changes Requested --- .../kernel_tests/prefetch_to_device_test.py | 4 +-- .../python/data/kernel_tests/iterator_test.py | 27 ++++++++++--------- 2 files changed, 17 insertions(+), 14 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py index 9f0ccefe4a1..b365c0b2a38 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py @@ -162,10 +162,10 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): self.assertTrue(( "" == host_dataset._variant_tensor.device or - "CPU:0" in host_dataset._variant_tensor.device + "cpu:0" in host_dataset._variant_tensor.device.lower() )) - self.assertTrue("GPU:0" in device_dataset._variant_tensor.device) + self.assertTrue("gpu:0" in device_dataset._variant_tensor.device.lower()) @combinations.generate(test_base.graph_only_combinations()) def testPrefetchToDeviceWithReInit(self): diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index 3e914f8a373..73f258e4aa1 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -1021,21 +1021,24 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): def assert_dataset_placement(self, host_dataset, host_iterator, host_tensor, device_dataset, device_iterator, device_tensor): - def assert_host_placement(_obj): - try: - self.assertIn("CPU:0", _obj) - except AssertionError: - self.assertEqual(_obj, "") + self.assertTrue( + "cpu:0" in host_dataset._variant_tensor.device.lower() or + host_dataset._variant_tensor.device == "" + ) + self.assertTrue( + "cpu:0" in host_tensor._variant_tensor.device.lower() or + host_tensor._variant_tensor.device == "" + ) - assert_host_placement(host_dataset._variant_tensor.device) - assert_host_placement(host_tensor.device) - - self.assertIn("GPU:0", device_dataset._variant_tensor.device) - self.assertIn("GPU:0", device_tensor.device) + self.assertIn("gpu:0", device_dataset._variant_tensor.device.lower()) + self.assertIn("gpu:0", device_tensor.device.lower()) if not tf2.enabled() or context.executing_eagerly(): - assert_host_placement(host_iterator._device) - self.assertIn("GPU:0", device_iterator._device) + self.assertTrue( + "cpu:0" in host_iterator._variant_tensor.device.lower() or + host_iterator._variant_tensor.device == "" + ) + self.assertIn("gpu:0", device_iterator._device.lower()) @combinations.generate(test_base.eager_only_combinations()) def testIteratorOnDeviceEagerMode(self): From 7b58da208ad33a8cefd27dd6c5fddcc35fab5c33 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Thu, 9 Apr 2020 11:14:53 -0700 Subject: [PATCH 16/19] Bug Fix --- tensorflow/python/data/kernel_tests/iterator_test.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index 73f258e4aa1..9088b09c8cf 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -1026,8 +1026,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): host_dataset._variant_tensor.device == "" ) self.assertTrue( - "cpu:0" in host_tensor._variant_tensor.device.lower() or - host_tensor._variant_tensor.device == "" + "cpu:0" in host_tensor.device.lower() or host_tensor.device == "" ) self.assertIn("gpu:0", device_dataset._variant_tensor.device.lower()) @@ -1035,8 +1034,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): if not tf2.enabled() or context.executing_eagerly(): self.assertTrue( - "cpu:0" in host_iterator._variant_tensor.device.lower() or - host_iterator._variant_tensor.device == "" + "cpu:0" in host_iterator._device.lower() or host_iterator._device == "" ) self.assertIn("gpu:0", device_iterator._device.lower()) From aa329b63a9ea545139d1fc0a14ed1c7aa114da59 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Wed, 22 Apr 2020 14:35:45 -0700 Subject: [PATCH 17/19] Unittest Changes requested --- tensorflow/python/data/kernel_tests/iterator_test.py | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index 9088b09c8cf..b43f92ea687 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -1025,19 +1025,18 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): "cpu:0" in host_dataset._variant_tensor.device.lower() or host_dataset._variant_tensor.device == "" ) + self.assertTrue( + "cpu:0" in host_iterator._iterator_resource.device.lower() or + host_iterator._iterator_resource.device == "" + ) self.assertTrue( "cpu:0" in host_tensor.device.lower() or host_tensor.device == "" ) self.assertIn("gpu:0", device_dataset._variant_tensor.device.lower()) + self.assertIn("gpu:0", device_iterator._iterator_resource.device.lower()) self.assertIn("gpu:0", device_tensor.device.lower()) - if not tf2.enabled() or context.executing_eagerly(): - self.assertTrue( - "cpu:0" in host_iterator._device.lower() or host_iterator._device == "" - ) - self.assertIn("gpu:0", device_iterator._device.lower()) - @combinations.generate(test_base.eager_only_combinations()) def testIteratorOnDeviceEagerMode(self): if not test_util.is_gpu_available(): From 0f26cfa90c2477392da1b1409f690c70cdf050f7 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Wed, 22 Apr 2020 15:39:43 -0700 Subject: [PATCH 18/19] Linting Issues Fix --- .../kernel_tests/prefetch_to_device_test.py | 4 +-- .../python/data/kernel_tests/iterator_test.py | 29 +++++++++---------- tensorflow/python/data/ops/iterator_ops.py | 8 ++--- 3 files changed, 20 insertions(+), 21 deletions(-) diff --git a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py index b365c0b2a38..52d33a3668d 100644 --- a/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/prefetch_to_device_test.py @@ -161,8 +161,8 @@ class PrefetchToDeviceTest(test_base.DatasetTestBase, parameterized.TestCase): prefetching_ops.prefetch_to_device("/gpu:0")) self.assertTrue(( - "" == host_dataset._variant_tensor.device or - "cpu:0" in host_dataset._variant_tensor.device.lower() + "" == host_dataset._variant_tensor.device or + "cpu:0" in host_dataset._variant_tensor.device.lower() )) self.assertTrue("gpu:0" in device_dataset._variant_tensor.device.lower()) diff --git a/tensorflow/python/data/kernel_tests/iterator_test.py b/tensorflow/python/data/kernel_tests/iterator_test.py index b43f92ea687..5d711d0a749 100644 --- a/tensorflow/python/data/kernel_tests/iterator_test.py +++ b/tensorflow/python/data/kernel_tests/iterator_test.py @@ -25,7 +25,6 @@ import numpy as np from tensorflow.core.protobuf import cluster_pb2 from tensorflow.core.protobuf import config_pb2 from tensorflow.python.client import session -from tensorflow.python import tf2 from tensorflow.python.data.experimental.ops import prefetching_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops @@ -1022,15 +1021,15 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): device_dataset, device_iterator, device_tensor): self.assertTrue( - "cpu:0" in host_dataset._variant_tensor.device.lower() or - host_dataset._variant_tensor.device == "" + "cpu:0" in host_dataset._variant_tensor.device.lower() or + host_dataset._variant_tensor.device == "" ) self.assertTrue( - "cpu:0" in host_iterator._iterator_resource.device.lower() or - host_iterator._iterator_resource.device == "" + "cpu:0" in host_iterator._iterator_resource.device.lower() or + host_iterator._iterator_resource.device == "" ) self.assertTrue( - "cpu:0" in host_tensor.device.lower() or host_tensor.device == "" + "cpu:0" in host_tensor.device.lower() or host_tensor.device == "" ) self.assertIn("gpu:0", device_dataset._variant_tensor.device.lower()) @@ -1044,7 +1043,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/gpu:0")) + prefetching_ops.prefetch_to_device("/gpu:0")) host_iterator = iter(host_dataset) device_iterator = iter(device_dataset) @@ -1053,8 +1052,8 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): device_tensor = next(device_iterator) self.assert_dataset_placement( - host_dataset, host_iterator, host_tensor, - device_dataset, device_iterator, device_tensor + host_dataset, host_iterator, host_tensor, + device_dataset, device_iterator, device_tensor ) @combinations.generate(test_base.graph_only_combinations()) @@ -1064,7 +1063,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/gpu:0")) + prefetching_ops.prefetch_to_device("/gpu:0")) host_iterator = dataset_ops.make_one_shot_iterator(host_dataset) device_iterator = dataset_ops.make_one_shot_iterator(device_dataset) @@ -1073,8 +1072,8 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): device_tensor = device_iterator.get_next() self.assert_dataset_placement( - host_dataset, host_iterator, host_tensor, - device_dataset, device_iterator, device_tensor + host_dataset, host_iterator, host_tensor, + device_dataset, device_iterator, device_tensor ) @combinations.generate(test_base.graph_only_combinations()) @@ -1084,7 +1083,7 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): host_dataset = dataset_ops.Dataset.range(10) device_dataset = host_dataset.apply( - prefetching_ops.prefetch_to_device("/gpu:0")) + prefetching_ops.prefetch_to_device("/gpu:0")) host_iterator = dataset_ops.make_initializable_iterator(host_dataset) device_iterator = dataset_ops.make_initializable_iterator(device_dataset) @@ -1093,8 +1092,8 @@ class IteratorTest(test_base.DatasetTestBase, parameterized.TestCase): device_tensor = device_iterator.get_next() self.assert_dataset_placement( - host_dataset, host_iterator, host_tensor, - device_dataset, device_iterator, device_tensor + host_dataset, host_iterator, host_tensor, + device_dataset, device_iterator, device_tensor ) diff --git a/tensorflow/python/data/ops/iterator_ops.py b/tensorflow/python/data/ops/iterator_ops.py index 10fc018615e..15a7e4bed8f 100644 --- a/tensorflow/python/data/ops/iterator_ops.py +++ b/tensorflow/python/data/ops/iterator_ops.py @@ -423,10 +423,10 @@ class Iterator(trackable.Trackable): with ops.device(self._iterator_resource.device): # pylint: disable=protected-access flat_ret = gen_dataset_ops.iterator_get_next( - self._iterator_resource, - output_types=self._flat_tensor_types, - output_shapes=self._flat_tensor_shapes, - name=name) + self._iterator_resource, + output_types=self._flat_tensor_types, + output_shapes=self._flat_tensor_shapes, + name=name) return structure.from_tensor_list(self._element_spec, flat_ret) def string_handle(self, name=None): From 356a3dd5a6acb5fe8f2a15dd6027d709bd7632f7 Mon Sep 17 00:00:00 2001 From: DEKHTIARJonathan Date: Thu, 23 Apr 2020 09:19:26 -0700 Subject: [PATCH 19/19] Linting Issues Fix --- tensorflow/python/data/ops/dataset_ops.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 40411e6a7a7..d5619370624 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2261,7 +2261,7 @@ class DatasetV1(DatasetV2): with ops.device(self._variant_tensor.device): iterator_resource = gen_dataset_ops.iterator_v2( - container="", shared_name=shared_name, **self._flat_structure) + container="", shared_name=shared_name, **self._flat_structure) initializer = gen_dataset_ops.make_iterator( dataset._variant_tensor, # pylint: disable=protected-access @@ -2269,8 +2269,8 @@ class DatasetV1(DatasetV2): # pylint: disable=protected-access return iterator_ops.Iterator( - iterator_resource, initializer, get_legacy_output_types(dataset), - get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset)) + iterator_resource, initializer, get_legacy_output_types(dataset), + get_legacy_output_shapes(dataset), get_legacy_output_classes(dataset)) @property @deprecation.deprecated( @@ -4380,7 +4380,7 @@ class PrefetchDataset(UnaryUnchangedStructureDataset): buffer_size=self._buffer_size, slack_period=slack_period, **self._flat_structure) - + super(PrefetchDataset, self).__init__(input_dataset, variant_tensor)