From 6c102af1716d0dffe9d111493b326a519970ebd3 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Fri, 18 Sep 2020 00:19:26 -0700 Subject: [PATCH] Move the logic to decide whether to enable partial batch handling into dataset When partial batch handling is enabled, we should change the element spec to have a unknown batch dimension. This will be done in a later change. PiperOrigin-RevId: 332398721 Change-Id: I3fe09c1a12fc7b393d94930aa0da4bddb7bb0666 --- tensorflow/python/distribute/input_lib.py | 189 ++++++++++------------ 1 file changed, 86 insertions(+), 103 deletions(-) diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 670325d87c3..d689346870e 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -552,41 +552,67 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False): return global_has_value, replicas -def _is_statically_shaped(element_spec): +def _is_statically_shaped(tensor_class, shape): """Test if an iterator output is statically shaped. For sparse and ragged tensors this only tests the batch dimension. Args: - element_spec: a nest structure of `tf.TypeSpec`. The element spec of the - dataset of the iterator. + tensor_class: a class from an iterator.output_classes list. + shape: a TensorShape from an iterator.output_shapes list. Returns: True if the shape is static, false otherwise. """ + if (tensor_class == sparse_tensor.SparseTensor or + isinstance(tensor_class, ragged_tensor.RaggedTensorSpec)): + # For sparse or ragged tensor, we should only check the first + # dimension in order to get_next_as_optional. This is because + # when these tensors get batched by dataset only the batch dimension + # is set. + if shape.rank > 0 and shape.as_list()[0] is None: + return False + return True + return shape.is_fully_defined() - for spec in nest.flatten(element_spec): - shape = spec.shape - if isinstance( - spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)): - # For sparse or ragged tensor, we should only check the first - # dimension in order to get_next_as_optional. This is because - # when these tensors get batched by dataset only the batch dimension - # is set. - if shape.rank > 0 and shape.as_list()[0] is None: - return False - else: - if not shape.is_fully_defined(): - return False - return True + +def _get_static_shape(iterators): + """Returns a boolean indicating if the input is fully defined.""" + static_shape = True + for iterator in iterators: + if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator, + _SingleWorkerDatasetIterator)): + continue + flattened = zip(nest.flatten(iterator.output_shapes), + nest.flatten(iterator.output_classes)) + for output_shape, output_class in flattened: + if not _is_statically_shaped(output_class, output_shape): + static_shape = False + break + return static_shape class DistributedIteratorBase(DistributedIteratorInterface): """Common implementation for all input iterators.""" # pylint: disable=super-init-not-called - def __init__(self, input_workers, iterators, strategy, - enable_get_next_as_optional): + def __init__(self, input_workers, iterators, strategy): + static_shape = _get_static_shape(iterators) + + # TODO(b/133073708): we currently need a flag to control the usage because + # there is a performance difference between get_next() and + # get_next_as_optional(). And we only enable get_next_as_optional when the + # output shapes are not static. + # + # TODO(rxsang): We want to always enable the get_next_as_optional behavior + # when user passed input_fn instead of dataset. + if getattr( + strategy.extended, "experimental_enable_get_next_as_optional", False): + self._enable_get_next_as_optional = ( + not static_shape) or strategy.extended._in_multi_worker_mode() + else: + self._enable_get_next_as_optional = False + assert isinstance(input_workers, InputWorkers) if not input_workers.worker_devices: raise ValueError("Should have at least one worker for input iterator.") @@ -594,7 +620,6 @@ class DistributedIteratorBase(DistributedIteratorInterface): self._iterators = iterators self._input_workers = input_workers self._strategy = strategy - self._enable_get_next_as_optional = enable_get_next_as_optional def next(self): return self.__next__() @@ -728,13 +753,9 @@ class DistributedIteratorV1(DistributedIteratorBase): class DistributedIteratorSpec(type_spec.TypeSpec): """Type specification for `DistributedIterator`.""" - __slots__ = [ - "_input_workers", "_element_spec", "_strategy", - "_enable_get_next_as_optional" - ] + __slots__ = ["_input_workers", "_element_spec", "_strategy"] - def __init__(self, input_workers, element_spec, strategy, - enable_get_next_as_optional): + def __init__(self, input_workers, element_spec, strategy): # We don't want to allow deserialization of this class because we don't # serialize the strategy object. Currently the only places where # _deserialize is called is when we save/restore using SavedModels. @@ -745,7 +766,6 @@ class DistributedIteratorSpec(type_spec.TypeSpec): self._input_workers = input_workers self._element_spec = element_spec self._strategy = strategy - self._enable_get_next_as_optional = enable_get_next_as_optional @property def value_type(self): @@ -786,8 +806,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec): lambda a, b: a.most_specific_compatible_type(b), self._element_spec, other._element_spec) return DistributedIteratorSpec(self._input_workers, element_spec, - self._strategy, - self._enable_get_next_as_optional) + self._strategy) @property def _component_specs(self): @@ -806,41 +825,32 @@ class DistributedIteratorSpec(type_spec.TypeSpec): return value._iterators # pylint: disable=protected-access def _from_components(self, components): - return DistributedIterator( - input_workers=self._input_workers, - iterators=None, - components=components, - element_spec=self._element_spec, - strategy=self._strategy, - enable_get_next_as_optional=self._enable_get_next_as_optional) + return DistributedIterator(input_workers=self._input_workers, + iterators=None, + components=components, + element_spec=self._element_spec, + strategy=self._strategy) @staticmethod def from_value(value): # pylint: disable=protected-access return DistributedIteratorSpec(value._input_workers, value._element_spec, - value._strategy, - value._enable_get_next_as_optional) + value._strategy) def _with_tensor_ranks_only(self): element_spec = nest.map_structure( lambda s: s._with_tensor_ranks_only(), # pylint: disable=protected-access self._element_spec) return DistributedIteratorSpec(self._input_workers, element_spec, - self._strategy, - self._enable_get_next_as_optional) + self._strategy) class DistributedIterator(DistributedIteratorBase, composite_tensor.CompositeTensor): """Input Iterator for a distributed dataset.""" - def __init__(self, - input_workers=None, - iterators=None, - strategy=None, - components=None, - element_spec=None, - enable_get_next_as_optional=False): + def __init__(self, input_workers=None, iterators=None, strategy=None, + components=None, element_spec=None): if input_workers is None: raise ValueError("`input_workers` should be " "provided.") @@ -855,15 +865,20 @@ class DistributedIterator(DistributedIteratorBase, self._element_spec = element_spec self._input_workers = input_workers self._iterators = components + static_shape = _get_static_shape(self._iterators) self._strategy = strategy - self._enable_get_next_as_optional = enable_get_next_as_optional + if getattr(strategy.extended, + "experimental_enable_get_next_as_optional", False): + self._enable_get_next_as_optional = ( + not static_shape) or strategy.extended._in_multi_worker_mode() + else: + self._enable_get_next_as_optional = False else: if (components is not None and element_spec is not None): raise ValueError(error_message) - super(DistributedIterator, - self).__init__(input_workers, iterators, strategy, - enable_get_next_as_optional) + super(DistributedIterator, self).__init__(input_workers, iterators, + strategy) @property def element_spec(self): @@ -871,9 +886,9 @@ class DistributedIterator(DistributedIteratorBase, @property def _type_spec(self): - return DistributedIteratorSpec(self._input_workers, self.element_spec, - self._strategy, - self._enable_get_next_as_optional) + return DistributedIteratorSpec(self._input_workers, + self.element_spec, + self._strategy) class _IterableInput(DistributedDatasetInterface): @@ -986,8 +1001,6 @@ class DistributedDataset(_IterableInput): self._input_workers = input_workers self._strategy = strategy - self._enable_get_next_as_optional = _enable_get_next_as_optional( - self._strategy, dataset.element_spec) self._element_spec = _create_distributed_tensor_spec(self._strategy, dataset.element_spec) # pylint: disable=protected-access @@ -1006,17 +1019,11 @@ class DistributedDataset(_IterableInput): self._input_workers, enable_legacy_iterators) if enable_legacy_iterators: - iterator = DistributedIteratorV1( - self._input_workers, - worker_iterators, - self._strategy, - enable_get_next_as_optional=self._enable_get_next_as_optional) + iterator = DistributedIteratorV1(self._input_workers, worker_iterators, + self._strategy) else: - iterator = DistributedIterator( - self._input_workers, - worker_iterators, - self._strategy, - enable_get_next_as_optional=self._enable_get_next_as_optional) + iterator = DistributedIterator(self._input_workers, worker_iterators, + self._strategy) iterator._element_spec = self.element_spec # pylint: disable=protected-access # When async eager is enabled, sometimes the iterator may not finish @@ -1097,8 +1104,7 @@ class DistributedDatasetV1(DistributedDataset): self._input_workers, True) iterator = DistributedIteratorV1(self._input_workers, worker_iterators, - self._strategy, - self._enable_get_next_as_optional) + self._strategy) iterator._element_spec = self.element_spec # pylint: disable=protected-access # When async eager is enabled, sometimes the iterator may not finish @@ -1150,8 +1156,6 @@ class DistributedDatasetsFromFunction(_IterableInput): _create_datasets_per_worker_with_input_context(self._input_contexts, self._input_workers, dataset_fn)) - self._enable_get_next_as_optional = _enable_get_next_as_optional( - self._strategy, element_spec) self._element_spec = _create_distributed_tensor_spec( self._strategy, element_spec) @@ -1170,17 +1174,11 @@ class DistributedDatasetsFromFunction(_IterableInput): enable_legacy_iterators) if enable_legacy_iterators: - iterator = DistributedIteratorV1( - self._input_workers, - iterators, - self._strategy, - enable_get_next_as_optional=self._enable_get_next_as_optional) + iterator = DistributedIteratorV1(self._input_workers, iterators, + self._strategy) else: - iterator = DistributedIterator( - self._input_workers, - iterators, - self._strategy, - enable_get_next_as_optional=self._enable_get_next_as_optional) + iterator = DistributedIterator(self._input_workers, iterators, + self._strategy) iterator._element_spec = self._element_spec # pylint: disable=protected-access # When async eager is enabled, sometimes the iterator may not finish @@ -1227,8 +1225,7 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction): iterators = _create_iterators_per_worker(self._datasets, self._input_workers, True) iterator = DistributedIteratorV1(self._input_workers, iterators, - self._strategy, - self._enable_get_next_as_optional) + self._strategy) iterator._element_spec = self._element_spec # pylint: disable=protected-access # When async eager is enabled, sometimes the iterator may not finish @@ -1292,8 +1289,9 @@ class InputFunctionIterator(DistributedIteratorV1): "input_fn must return a tf.data.Dataset or a callable.") iterators.append(iterator) - super(InputFunctionIterator, self).__init__( - input_workers, iterators, strategy, enable_get_next_as_optional=False) + super(InputFunctionIterator, self).__init__(input_workers, iterators, + strategy) + self._enable_get_next_as_optional = False # TODO(anjalisridhar): This class will soon be removed and users should move @@ -1332,9 +1330,10 @@ class DatasetIterator(DistributedIteratorV1): input_context=input_context) worker_iterators = _create_iterators_per_worker( dist_dataset._cloned_datasets, input_workers, True) # pylint: disable=protected-access - super(DatasetIterator, - self).__init__(input_workers, worker_iterators, strategy, - dist_dataset._enable_get_next_as_optional) # pylint: disable=protected-access + super(DatasetIterator, self).__init__( + input_workers, + worker_iterators, # pylint: disable=protected-access + strategy) self._element_spec = dist_dataset.element_spec @@ -1953,19 +1952,3 @@ def _replace_per_replica_spec(spec, i): return spec._value_specs[i] # pylint: disable=protected-access else: return spec - - -def _enable_get_next_as_optional(strategy, element_spec): - """Returns whether to enable using partial batch handling.""" - # TODO(b/133073708): we currently need a flag to control the usage because - # there is a performance difference between get_next() and - # get_next_as_optional(). And we only enable get_next_as_optional when the - # output shapes are not static. - # - # TODO(rxsang): We want to always enable the get_next_as_optional behavior - # when user passed input_fn instead of dataset. - if not getattr(strategy.extended, "experimental_enable_get_next_as_optional", - False): - return False - return not _is_statically_shaped( - element_spec) or strategy.extended._in_multi_worker_mode() # pylint: disable=protected-access