From 2b56e1234b24a0287cd8df1061b512bb4e263c61 Mon Sep 17 00:00:00 2001 From: kushanam Date: Thu, 5 Nov 2020 12:47:38 -0800 Subject: [PATCH 1/8] add per_replica support for keras --- tensorflow/python/distribute/input_lib.py | 73 +++++++++++++++-------- 1 file changed, 47 insertions(+), 26 deletions(-) diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 66460ebec44..633ee867004 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -29,6 +29,7 @@ from tensorflow.python.data.experimental.ops import cardinality from tensorflow.python.data.experimental.ops import distribute from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import multi_device_iterator_ops +from tensorflow.python.data.ops import iterator_ops from tensorflow.python.data.ops import optional_ops from tensorflow.python.distribute import device_util from tensorflow.python.distribute import distribute_utils @@ -759,11 +760,11 @@ class DistributedIteratorSpec(type_spec.TypeSpec): __slots__ = [ "_input_workers", "_element_spec", "_strategy", - "_enable_get_next_as_optional" + "_enable_get_next_as_optional", "_options" ] def __init__(self, input_workers, element_spec, strategy, - enable_get_next_as_optional): + enable_get_next_as_optional, options): # We don't want to allow deserialization of this class because we don't # serialize the strategy object. Currently the only places where # _deserialize is called is when we save/restore using SavedModels. @@ -775,6 +776,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec): self._element_spec = element_spec self._strategy = strategy self._enable_get_next_as_optional = enable_get_next_as_optional + self._options = options @property def value_type(self): @@ -784,7 +786,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec): # We cannot serialize the strategy object so we convert it to an id that we # can use for comparison. return (self._input_workers.serialize(), - self._element_spec, id(self._strategy)) + self._element_spec, id(self._strategy), id(self._options)) def _deserialize(self): raise ValueError("Deserialization is currently unsupported for " @@ -816,7 +818,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec): other._element_spec) return DistributedIteratorSpec(self._input_workers, element_spec, self._strategy, - self._enable_get_next_as_optional) + self._enable_get_next_as_optional, + self._options) @property def _component_specs(self): @@ -828,7 +831,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec): functools.partial(_replace_per_replica_spec, i=i), self._element_spec) specs.append(_SingleWorkerDatasetIteratorSpec(input_device, compute_devices, - element_spec)) + element_spec, + self._options)) return specs def _to_components(self, value): @@ -841,14 +845,16 @@ class DistributedIteratorSpec(type_spec.TypeSpec): components=components, element_spec=self._element_spec, strategy=self._strategy, - enable_get_next_as_optional=self._enable_get_next_as_optional) + enable_get_next_as_optional=self._enable_get_next_as_optional, + options=self._options) @staticmethod def from_value(value): # pylint: disable=protected-access return DistributedIteratorSpec(value._input_workers, value._element_spec, value._strategy, - value._enable_get_next_as_optional) + value._enable_get_next_as_optional, + value._options) def _with_tensor_ranks_only(self): element_spec = nest.map_structure( @@ -856,7 +862,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec): self._element_spec) return DistributedIteratorSpec(self._input_workers, element_spec, self._strategy, - self._enable_get_next_as_optional) + self._enable_get_next_as_optional, + self._options) class DistributedIterator(DistributedIteratorBase, @@ -869,7 +876,8 @@ class DistributedIterator(DistributedIteratorBase, strategy=None, components=None, element_spec=None, - enable_get_next_as_optional=False): + enable_get_next_as_optional=False, + options=None): if input_workers is None: raise ValueError("`input_workers` should be " "provided.") @@ -877,6 +885,7 @@ class DistributedIterator(DistributedIteratorBase, error_message = ("Either `input_workers` or " "both `components` and `element_spec` need to be " "provided.") + self._options = options if iterators is None: if (components is None or element_spec is None): @@ -916,7 +925,8 @@ class DistributedIterator(DistributedIteratorBase, # TODO(b/163362689): remove the comment after the bug if fixed. return DistributedIteratorSpec(self._input_workers, self._element_spec, self._strategy, - self._enable_get_next_as_optional) + self._enable_get_next_as_optional, + self._options) class _IterableInput(DistributedDatasetInterface): @@ -1290,7 +1300,8 @@ class DistributedDatasetsFromFunction(_IterableInput): input_workers=self._input_workers, iterators=iterators, strategy=self._strategy, - enable_get_next_as_optional=self._enable_get_next_as_optional) + enable_get_next_as_optional=self._enable_get_next_as_optional, + options=self._options) iterator._element_spec = self._element_spec # pylint: disable=protected-access # When async eager is enabled, sometimes the iterator may not finish @@ -1585,9 +1596,7 @@ class _SingleWorkerDatasetIteratorBase(object): """Get next element for the given device.""" del name with ops.device(self._worker): - if isinstance(self._iterator, - (multi_device_iterator_ops.OwnedMultiDeviceIterator, - multi_device_iterator_ops.MultiDeviceIterator)): + if _should_use_multi_device_iterator(self._options): return self._iterator.get_next(device) else: return self._iterator.get_next() @@ -1665,25 +1674,30 @@ class _SingleWorkerDatasetIteratorBase(object): class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec): """Type specification for `_SingleWorkerOwnedDatasetIterator`.""" - __slots__ = ["_worker", "_devices", "_element_spec"] + __slots__ = ["_worker", "_devices", "_element_spec", "_options"] def __init__(self, worker, devices, element_spec): self._worker = worker self._devices = tuple(device_util.canonicalize(d) for d in devices) self._element_spec = element_spec + self._options = options @property def value_type(self): return _SingleWorkerOwnedDatasetIterator def _serialize(self): - return (self._worker, self._devices, self._element_spec) + return (self._worker, self._devices, self._element_spec, self._options) @property def _component_specs(self): specs = [] - specs.append(multi_device_iterator_ops.MultiDeviceIteratorSpec( - self._devices, self._worker, element_spec=self._element_spec)) + if _should_use_multi_device_iterator(self._options): + specs.append(multi_device_iterator_ops.MultiDeviceIteratorSpec( + self._devices, self._worker, element_spec=self._element_spec)) + else: + specs.append(iterator_ops.IteratorSpec( + element_spec=self._element_spec)) return specs def _to_components(self, value): @@ -1695,13 +1709,14 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec): worker=self._worker, devices=self._devices, components=components, - element_spec=self._element_spec) + element_spec=self._element_spec, + options=self._options) @staticmethod def from_value(value): # pylint: disable=protected-access return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices, - value._element_spec) + value._element_spec, value._options) class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase, @@ -1758,10 +1773,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase, if not self._worker: raise ValueError("Worked device must be specified when creating an " "owned iterator.") - if (self._options is None or self._options.experimental_replication_mode == - InputReplicationMode.PER_WORKER or - (self._options.experimental_replication_mode == InputReplicationMode - .PER_REPLICA and self._options.experimental_prefetch_to_device)): + if _should_use_multi_device_iterator(self._options): host_device = device_util.get_host_for_device(self._worker) with ops.device(self._worker): self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator( @@ -1777,7 +1789,7 @@ class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase, @property def _type_spec(self): return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices, - self._element_spec) + self._element_spec, self._options) @property def output_classes(self): @@ -1908,7 +1920,7 @@ def _create_iterators_per_worker(worker_datasets, options=options) else: iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker, - worker_devices) + worker_devices, options) iterators.append(iterator) return iterators @@ -1988,6 +2000,15 @@ def _get_dataset_attributes(dataset): return batch_size, drop_remainder, prefetch_buffer +def _should_use_multi_device_iterator(options): + """Determine whether to use multi_device_iterator_ops.OwnedMultiDeviceIterator""" + if (options is None + or options.experimental_replication_mode == InputReplicationMode.PER_WORKER + or (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA + and options.experimental_prefetch_to_device)): + return True + return False + class MultiStepContext(object): """A context object that can be used to capture things when running steps. From 91519454ae996d8145b2500b4a5014b6b73e6787 Mon Sep 17 00:00:00 2001 From: kushanam Date: Thu, 5 Nov 2020 22:23:53 -0800 Subject: [PATCH 2/8] add missing options --- tensorflow/python/distribute/input_lib.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 633ee867004..3f3230517c6 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -1676,7 +1676,7 @@ class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec): __slots__ = ["_worker", "_devices", "_element_spec", "_options"] - def __init__(self, worker, devices, element_spec): + def __init__(self, worker, devices, element_spec, options): self._worker = worker self._devices = tuple(device_util.canonicalize(d) for d in devices) self._element_spec = element_spec From 10592307811e8967aa305624e794ac1edabdbfc4 Mon Sep 17 00:00:00 2001 From: kushanam Date: Fri, 6 Nov 2020 14:20:37 -0800 Subject: [PATCH 3/8] adding typeSpec test --- .../distribute/input_lib_type_spec_test.py | 47 +++++++++++++++++++ 1 file changed, 47 insertions(+) diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py index 940949efd87..108731320ce 100644 --- a/tensorflow/python/distribute/input_lib_type_spec_test.py +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -450,6 +450,53 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): f(v) self.assertEqual(self.trace_count, 1) + @combinations.generate( + combinations.combine( + mode=["eager"], + tf_api_version=2, + distribution=[ + strategy_combinations.mirrored_strategy_with_two_gpus, + strategy_combinations.mirrored_strategy_with_cpu_1_and_2, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + ], + enable_get_next_as_optional=[True, False])) + def testTypeSpecForPerReplicaOptions(self, distribution, enable_get_next_as_optional): + + fname1 = os.path.join(self.get_temp_dir(), "1.txt") + _create_text_file(fname1, 5) + fname2 = os.path.join(self.get_temp_dir(), "2.txt") + _create_text_file(fname2, 9) + + def dataset_fn(input_context): + dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2]) + dataset = dataset.shard(input_context.num_input_pipelines, + input_context.input_pipeline_id) + return readers.TextLineDatasetV2(dataset).map( + string_ops.string_to_number).batch( + input_context.get_per_replica_batch_size(4)) + + options = distribute_lib.InputOptions( + experimental_place_dataset_on_device = True, + experimental_prefetch_to_device = False, + experimental_replication_mode = distribute_lib.InputReplicationMode.PER_REPLICA) + + ds = distribution.experimental_distribute_datasets_from_function(dataset_fn, options) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + + with distribution.scope(): + iterator = iter(ds) + _check_type_spec_structure(iterator) + + spec = iterator._type_spec + + tensor_list = spec._to_components(iterator) + re_iterator = spec._from_components(tensor_list) + + self.assertEqual(iterator._input_workers, re_iterator._input_workers) + self.assertAllEqual(iterator._iterators, re_iterator._iterators) + class RaggedTensorDistributedIteratorTest(test.TestCase, parameterized.TestCase): From bbe5cfdca903678c718dc98d1aed691bf2d8ae1c Mon Sep 17 00:00:00 2001 From: kushanam Date: Fri, 6 Nov 2020 15:00:01 -0800 Subject: [PATCH 4/8] adding iterator flatten by nest test --- .../distribute/input_lib_type_spec_test.py | 36 +++++++++++++++++++ 1 file changed, 36 insertions(+) diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py index 108731320ce..745194fd965 100644 --- a/tensorflow/python/distribute/input_lib_type_spec_test.py +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -497,6 +497,42 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): self.assertEqual(iterator._input_workers, re_iterator._input_workers) self.assertAllEqual(iterator._iterators, re_iterator._iterators) + def testFromFunctionInputSignatureForPerReplicaValuesWithOptions(self, distribution, enable_get_next_as_optional): + + fname1 = os.path.join(self.get_temp_dir(), "1.txt") + _create_text_file(fname1, 5) + fname2 = os.path.join(self.get_temp_dir(), "2.txt") + _create_text_file(fname2, 9) + + def dataset_fn(input_context): + dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2]) + dataset = dataset.shard(input_context.num_input_pipelines, + input_context.input_pipeline_id) + return readers.TextLineDatasetV2(dataset).map( + string_ops.string_to_number).batch( + input_context.get_per_replica_batch_size(4)) + + options = distribute_lib.InputOptions( + experimental_place_dataset_on_device = True, + experimental_prefetch_to_device = False, + experimental_replication_mode = distribute_lib.InputReplicationMode.PER_REPLICA) + + distribution.extended.experimental_enable_get_next_as_optional = ( + enable_get_next_as_optional) + ds = distribution.experimental_distribute_datasets_from_function(dataset_fn, options) + _check_type_spec_structure(iter(ds)) + element_spec = ds.element_spec + iter_element_spec = iter(ds).element_spec + nest.assert_same_structure(element_spec, iter_element_spec) + self.assertAllEqual( + nest.flatten(element_spec), nest.flatten(iter_element_spec)) + + @def_function.function(input_signature=[element_spec]) + def process_inputs(inputs): + distribution.run(lambda inputs: inputs, args=(inputs,)) + + for x in ds: + process_inputs(x) class RaggedTensorDistributedIteratorTest(test.TestCase, parameterized.TestCase): From 27c89e1d29b62648bdd758d7009b41c02f4d55ad Mon Sep 17 00:00:00 2001 From: kushanam Date: Fri, 6 Nov 2020 16:24:25 -0800 Subject: [PATCH 5/8] address review changes --- .../distribute/input_lib_type_spec_test.py | 69 ++++++++----------- 1 file changed, 28 insertions(+), 41 deletions(-) diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py index 745194fd965..a0605229fb7 100644 --- a/tensorflow/python/distribute/input_lib_type_spec_test.py +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -459,8 +459,27 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): strategy_combinations.mirrored_strategy_with_cpu_1_and_2, strategy_combinations.mirrored_strategy_with_gpu_and_cpu, ], - enable_get_next_as_optional=[True, False])) - def testTypeSpecForPerReplicaOptions(self, distribution, enable_get_next_as_optional): + enable_get_next_as_optional=[True, False], + input_options=[ + distribute_lib.InputOptions( + experimental_place_dataset_on_device=True, + experimental_prefetch_to_device=False, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_REPLICA), + distribute_lib.InputOptions( + experimental_place_dataset_on_device=False, + experimental_prefetch_to_device=False, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_REPLICA), + distribute_lib.InputOptions( + experimental_place_dataset_on_device=False, + experimental_prefetch_to_device=True, + experimental_replication_mode=distribute_lib + .InputReplicationMode.PER_REPLICA), + ],)) + def testFromFunctionInputSignatureForPerReplicaValuesWithOptions(self, distribution, + enable_get_next_as_optional, + input_options): fname1 = os.path.join(self.get_temp_dir(), "1.txt") _create_text_file(fname1, 5) @@ -473,59 +492,27 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): input_context.input_pipeline_id) return readers.TextLineDatasetV2(dataset).map( string_ops.string_to_number).batch( - input_context.get_per_replica_batch_size(4)) - - options = distribute_lib.InputOptions( - experimental_place_dataset_on_device = True, - experimental_prefetch_to_device = False, - experimental_replication_mode = distribute_lib.InputReplicationMode.PER_REPLICA) - - ds = distribution.experimental_distribute_datasets_from_function(dataset_fn, options) + input_context.get_per_replica_batch_size(4)) distribution.extended.experimental_enable_get_next_as_optional = ( enable_get_next_as_optional) + ds = distribution.experimental_distribute_datasets_from_function(dataset_fn, + input_options) - with distribution.scope(): - iterator = iter(ds) - _check_type_spec_structure(iterator) - + iterator = iter(ds) + _check_type_spec_structure(iterator) spec = iterator._type_spec - tensor_list = spec._to_components(iterator) re_iterator = spec._from_components(tensor_list) - self.assertEqual(iterator._input_workers, re_iterator._input_workers) - self.assertAllEqual(iterator._iterators, re_iterator._iterators) - - def testFromFunctionInputSignatureForPerReplicaValuesWithOptions(self, distribution, enable_get_next_as_optional): - - fname1 = os.path.join(self.get_temp_dir(), "1.txt") - _create_text_file(fname1, 5) - fname2 = os.path.join(self.get_temp_dir(), "2.txt") - _create_text_file(fname2, 9) - - def dataset_fn(input_context): - dataset = dataset_ops.DatasetV2.from_tensor_slices([fname1, fname2]) - dataset = dataset.shard(input_context.num_input_pipelines, - input_context.input_pipeline_id) - return readers.TextLineDatasetV2(dataset).map( - string_ops.string_to_number).batch( - input_context.get_per_replica_batch_size(4)) - - options = distribute_lib.InputOptions( - experimental_place_dataset_on_device = True, - experimental_prefetch_to_device = False, - experimental_replication_mode = distribute_lib.InputReplicationMode.PER_REPLICA) - - distribution.extended.experimental_enable_get_next_as_optional = ( - enable_get_next_as_optional) - ds = distribution.experimental_distribute_datasets_from_function(dataset_fn, options) _check_type_spec_structure(iter(ds)) element_spec = ds.element_spec iter_element_spec = iter(ds).element_spec nest.assert_same_structure(element_spec, iter_element_spec) self.assertAllEqual( nest.flatten(element_spec), nest.flatten(iter_element_spec)) + self.assertEqual(iterator._input_workers, re_iterator._input_workers) + self.assertAllEqual(iterator._iterators, re_iterator._iterators) @def_function.function(input_signature=[element_spec]) def process_inputs(inputs): From 309350e9de11a6120f3f623269990b0ae6607884 Mon Sep 17 00:00:00 2001 From: kushanam Date: Tue, 10 Nov 2020 07:54:53 -0800 Subject: [PATCH 6/8] apply test review change --- .../distribute/input_lib_type_spec_test.py | 37 ++++++++----------- 1 file changed, 16 insertions(+), 21 deletions(-) diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py index a0605229fb7..55e830886b2 100644 --- a/tensorflow/python/distribute/input_lib_type_spec_test.py +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -460,26 +460,17 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): strategy_combinations.mirrored_strategy_with_gpu_and_cpu, ], enable_get_next_as_optional=[True, False], - input_options=[ - distribute_lib.InputOptions( - experimental_place_dataset_on_device=True, - experimental_prefetch_to_device=False, - experimental_replication_mode=distribute_lib - .InputReplicationMode.PER_REPLICA), - distribute_lib.InputOptions( - experimental_place_dataset_on_device=False, - experimental_prefetch_to_device=False, - experimental_replication_mode=distribute_lib - .InputReplicationMode.PER_REPLICA), - distribute_lib.InputOptions( - experimental_place_dataset_on_device=False, - experimental_prefetch_to_device=True, - experimental_replication_mode=distribute_lib - .InputReplicationMode.PER_REPLICA), - ],)) + experimental_place_dataset_on_device=[True,False], + experimental_prefetch_to_device=[True, False],)) def testFromFunctionInputSignatureForPerReplicaValuesWithOptions(self, distribution, enable_get_next_as_optional, - input_options): + experimental_place_dataset_on_device, + experimental_prefetch_to_device): + + if experimental_place_dataset_on_device and experimental_prefetch_to_device: + self.skipTest("Setting experimental_place_dataset_on_device and " + "experimental_prefetch_to_device to `True` is not allowed " + "when using distribute_lib.InputReplicationMode.PER_REPLICA.") fname1 = os.path.join(self.get_temp_dir(), "1.txt") _create_text_file(fname1, 5) @@ -492,12 +483,16 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): input_context.input_pipeline_id) return readers.TextLineDatasetV2(dataset).map( string_ops.string_to_number).batch( - input_context.get_per_replica_batch_size(4)) + input_context.get_per_replica_batch_size(4)) + + options = distribute_lib.InputOptions( + experimental_place_dataset_on_device = experimental_place_dataset_on_device, + experimental_prefetch_to_device = experimental_prefetch_to_device, + experimental_replication_mode = distribute_lib.InputReplicationMode.PER_REPLICA) distribution.extended.experimental_enable_get_next_as_optional = ( enable_get_next_as_optional) - ds = distribution.experimental_distribute_datasets_from_function(dataset_fn, - input_options) + ds = distribution.experimental_distribute_datasets_from_function(dataset_fn, options) iterator = iter(ds) _check_type_spec_structure(iterator) From d3e2a16524a8ed6c51ab7f578f21afef360c44af Mon Sep 17 00:00:00 2001 From: kushanam Date: Tue, 10 Nov 2020 16:01:21 -0800 Subject: [PATCH 7/8] correct pylint formattings --- tensorflow/python/distribute/input_lib.py | 6 +++--- tensorflow/python/distribute/input_lib_type_spec_test.py | 8 ++++---- 2 files changed, 7 insertions(+), 7 deletions(-) diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 3f3230517c6..0c9898f0588 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -2003,9 +2003,9 @@ def _get_dataset_attributes(dataset): def _should_use_multi_device_iterator(options): """Determine whether to use multi_device_iterator_ops.OwnedMultiDeviceIterator""" if (options is None - or options.experimental_replication_mode == InputReplicationMode.PER_WORKER - or (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA - and options.experimental_prefetch_to_device)): + or options.experimental_replication_mode == InputReplicationMode.PER_WORKER + or (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA + and options.experimental_prefetch_to_device)): return True return False diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py index 55e830886b2..2481a310a3a 100644 --- a/tensorflow/python/distribute/input_lib_type_spec_test.py +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -460,7 +460,7 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): strategy_combinations.mirrored_strategy_with_gpu_and_cpu, ], enable_get_next_as_optional=[True, False], - experimental_place_dataset_on_device=[True,False], + experimental_place_dataset_on_device=[True, False], experimental_prefetch_to_device=[True, False],)) def testFromFunctionInputSignatureForPerReplicaValuesWithOptions(self, distribution, enable_get_next_as_optional, @@ -486,9 +486,9 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): input_context.get_per_replica_batch_size(4)) options = distribute_lib.InputOptions( - experimental_place_dataset_on_device = experimental_place_dataset_on_device, - experimental_prefetch_to_device = experimental_prefetch_to_device, - experimental_replication_mode = distribute_lib.InputReplicationMode.PER_REPLICA) + experimental_place_dataset_on_device=experimental_place_dataset_on_device, + experimental_prefetch_to_device=experimental_prefetch_to_device, + experimental_replication_mode=distribute_lib.InputReplicationMode.PER_REPLICA) distribution.extended.experimental_enable_get_next_as_optional = ( enable_get_next_as_optional) From 9886910012d009c6b80c89b8c700d68c11e066d4 Mon Sep 17 00:00:00 2001 From: kushanam Date: Tue, 10 Nov 2020 18:06:55 -0800 Subject: [PATCH 8/8] correct pylint formattings - 2 --- tensorflow/python/distribute/input_lib.py | 9 ++++--- .../distribute/input_lib_type_spec_test.py | 26 +++++++++++-------- 2 files changed, 20 insertions(+), 15 deletions(-) diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 0c9898f0588..f8c540c0d7f 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -2002,10 +2002,11 @@ def _get_dataset_attributes(dataset): def _should_use_multi_device_iterator(options): """Determine whether to use multi_device_iterator_ops.OwnedMultiDeviceIterator""" - if (options is None - or options.experimental_replication_mode == InputReplicationMode.PER_WORKER - or (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA - and options.experimental_prefetch_to_device)): + if (options is None or + options.experimental_replication_mode == InputReplicationMode.PER_WORKER + or + (options.experimental_replication_mode == InputReplicationMode.PER_REPLICA + and options.experimental_prefetch_to_device)): return True return False diff --git a/tensorflow/python/distribute/input_lib_type_spec_test.py b/tensorflow/python/distribute/input_lib_type_spec_test.py index 2481a310a3a..595f9b35b25 100644 --- a/tensorflow/python/distribute/input_lib_type_spec_test.py +++ b/tensorflow/python/distribute/input_lib_type_spec_test.py @@ -462,15 +462,16 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): enable_get_next_as_optional=[True, False], experimental_place_dataset_on_device=[True, False], experimental_prefetch_to_device=[True, False],)) - def testFromFunctionInputSignatureForPerReplicaValuesWithOptions(self, distribution, - enable_get_next_as_optional, - experimental_place_dataset_on_device, - experimental_prefetch_to_device): - + def testFromFunctionInputSignatureForPerReplicaValuesWithOptions( + self, distribution, enable_get_next_as_optional, + experimental_place_dataset_on_device, + experimental_prefetch_to_device): + if experimental_place_dataset_on_device and experimental_prefetch_to_device: self.skipTest("Setting experimental_place_dataset_on_device and " - "experimental_prefetch_to_device to `True` is not allowed " - "when using distribute_lib.InputReplicationMode.PER_REPLICA.") + "experimental_prefetch_to_device to `True` is not " + "allowed when using " + "distribute_lib.InputReplicationMode.PER_REPLICA.") fname1 = os.path.join(self.get_temp_dir(), "1.txt") _create_text_file(fname1, 5) @@ -486,13 +487,16 @@ class InputTypeSpecTest(test.TestCase, parameterized.TestCase): input_context.get_per_replica_batch_size(4)) options = distribute_lib.InputOptions( - experimental_place_dataset_on_device=experimental_place_dataset_on_device, - experimental_prefetch_to_device=experimental_prefetch_to_device, - experimental_replication_mode=distribute_lib.InputReplicationMode.PER_REPLICA) + experimental_place_dataset_on_device=experimental_place_dataset_on_device, + experimental_prefetch_to_device=experimental_prefetch_to_device, + experimental_replication_mode=( + distribute_lib.InputReplicationMode.PER_REPLICA + )) distribution.extended.experimental_enable_get_next_as_optional = ( enable_get_next_as_optional) - ds = distribution.experimental_distribute_datasets_from_function(dataset_fn, options) + ds = distribution.experimental_distribute_datasets_from_function(dataset_fn, + options) iterator = iter(ds) _check_type_spec_structure(iterator)