Move the logic that decides whether to enable partial batch handling into the dataset classes.

When partial batch handling is enabled, we should change the element spec to have an unknown batch dimension. This will be done in a later change.

PiperOrigin-RevId: 332398721
Change-Id: I3fe09c1a12fc7b393d94930aa0da4bddb7bb0666
parent 40a443b85b
commit 6c102af171
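Background: "partial batch handling" concerns the final batch of an epoch, which is smaller than the configured batch size unless `drop_remainder=True` is set. A minimal sketch using only public `tf.data` APIs (an illustration, not code from this commit) of how that possibility surfaces as an unknown batch dimension in the element spec:

    import tensorflow as tf

    # Without drop_remainder the final batch of range(10).batch(3) has size 1,
    # so the batch dimension is statically unknown (None).
    print(tf.data.Dataset.range(10).batch(3).element_spec)
    # TensorSpec(shape=(None,), dtype=tf.int64, name=None)

    # With drop_remainder=True every emitted batch has exactly 3 elements.
    print(tf.data.Dataset.range(10).batch(3, drop_remainder=True).element_spec)
    # TensorSpec(shape=(3,), dtype=tf.int64, name=None)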
@@ -552,67 +552,41 @@ def _get_next_as_optional(iterator, strategy, return_per_replica=False):
   return global_has_value, replicas
 
 
-def _is_statically_shaped(tensor_class, shape):
+def _is_statically_shaped(element_spec):
   """Test if an iterator output is statically shaped.
 
   For sparse and ragged tensors this only tests the batch dimension.
 
   Args:
-    tensor_class: a class from an iterator.output_classes list.
-    shape: a TensorShape from an iterator.output_shapes list.
+    element_spec: a nest structure of `tf.TypeSpec`. The element spec of the
+      dataset of the iterator.
 
   Returns:
     True if the shape is static, false otherwise.
   """
-  if (tensor_class == sparse_tensor.SparseTensor or
-      isinstance(tensor_class, ragged_tensor.RaggedTensorSpec)):
-    # For sparse or ragged tensor, we should only check the first
-    # dimension in order to get_next_as_optional. This is because
-    # when these tensors get batched by dataset only the batch dimension
-    # is set.
-    if shape.rank > 0 and shape.as_list()[0] is None:
-      return False
-    return True
-  return shape.is_fully_defined()
-
-
-def _get_static_shape(iterators):
-  """Returns a boolean indicating if the input is fully defined."""
-  static_shape = True
-  for iterator in iterators:
-    if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator,
-                                 _SingleWorkerDatasetIterator)):
-      continue
-    flattened = zip(nest.flatten(iterator.output_shapes),
-                    nest.flatten(iterator.output_classes))
-    for output_shape, output_class in flattened:
-      if not _is_statically_shaped(output_class, output_shape):
-        static_shape = False
-        break
-  return static_shape
+  for spec in nest.flatten(element_spec):
+    shape = spec.shape
+    if isinstance(
+        spec, (sparse_tensor.SparseTensorSpec, ragged_tensor.RaggedTensorSpec)):
+      # For sparse or ragged tensor, we should only check the first
+      # dimension in order to get_next_as_optional. This is because
+      # when these tensors get batched by dataset only the batch dimension
+      # is set.
+      if shape.rank > 0 and shape.as_list()[0] is None:
+        return False
+    else:
+      if not shape.is_fully_defined():
+        return False
+  return True
 
 
 class DistributedIteratorBase(DistributedIteratorInterface):
   """Common implementation for all input iterators."""
 
   # pylint: disable=super-init-not-called
-  def __init__(self, input_workers, iterators, strategy):
-    static_shape = _get_static_shape(iterators)
-
-    # TODO(b/133073708): we currently need a flag to control the usage because
-    # there is a performance difference between get_next() and
-    # get_next_as_optional(). And we only enable get_next_as_optional when the
-    # output shapes are not static.
-    #
-    # TODO(rxsang): We want to always enable the get_next_as_optional behavior
-    # when user passed input_fn instead of dataset.
-    if getattr(
-        strategy.extended, "experimental_enable_get_next_as_optional", False):
-      self._enable_get_next_as_optional = (
-          not static_shape) or strategy.extended._in_multi_worker_mode()
-    else:
-      self._enable_get_next_as_optional = False
-
+  def __init__(self, input_workers, iterators, strategy,
+               enable_get_next_as_optional):
     assert isinstance(input_workers, InputWorkers)
     if not input_workers.worker_devices:
       raise ValueError("Should have at least one worker for input iterator.")
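The rewritten `_is_statically_shaped` inspects the dataset's `element_spec` instead of per-iterator `output_classes`/`output_shapes`. A standalone paraphrase of the check for illustration (the real helper is private to this module, so this sketch re-implements the same logic with public `tf.*` spec types):

    import tensorflow as tf

    def is_statically_shaped(element_spec):
      """Paraphrase of the new check, using public tf.* spec types."""
      for spec in tf.nest.flatten(element_spec):
        if isinstance(spec, (tf.SparseTensorSpec, tf.RaggedTensorSpec)):
          # For sparse/ragged values only the batch dimension matters here.
          if spec.shape.rank > 0 and spec.shape.as_list()[0] is None:
            return False
        else:
          if not spec.shape.is_fully_defined():
            return False
      return True

    print(is_statically_shaped(tf.TensorSpec([8, 2])))           # True
    print(is_statically_shaped(tf.TensorSpec([None, 2])))        # False
    print(is_statically_shaped(tf.SparseTensorSpec([8, None])))  # True: batch dim known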
@@ -620,6 +594,7 @@ class DistributedIteratorBase(DistributedIteratorInterface):
     self._iterators = iterators
     self._input_workers = input_workers
     self._strategy = strategy
+    self._enable_get_next_as_optional = enable_get_next_as_optional
 
   def next(self):
     return self.__next__()
@@ -753,9 +728,13 @@ class DistributedIteratorV1(DistributedIteratorBase):
 class DistributedIteratorSpec(type_spec.TypeSpec):
   """Type specification for `DistributedIterator`."""
 
-  __slots__ = ["_input_workers", "_element_spec", "_strategy"]
+  __slots__ = [
+      "_input_workers", "_element_spec", "_strategy",
+      "_enable_get_next_as_optional"
+  ]
 
-  def __init__(self, input_workers, element_spec, strategy):
+  def __init__(self, input_workers, element_spec, strategy,
+               enable_get_next_as_optional):
     # We don't want to allow deserialization of this class because we don't
     # serialize the strategy object. Currently the only places where
     # _deserialize is called is when we save/restore using SavedModels.
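Adding the flag to `__slots__` and the constructor makes it part of the iterator's `tf.TypeSpec`, so two iterators that differ only in partial-batch handling no longer produce equal specs. For intuition, this mirrors how ordinary specs compare unequal when any field differs (public-API sketch, not code from this change):

    import tensorflow as tf

    print(tf.TensorSpec([3], tf.float32) == tf.TensorSpec([3], tf.float32))     # True
    print(tf.TensorSpec([3], tf.float32) == tf.TensorSpec([None], tf.float32))  # False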
@@ -766,6 +745,7 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
     self._input_workers = input_workers
     self._element_spec = element_spec
     self._strategy = strategy
+    self._enable_get_next_as_optional = enable_get_next_as_optional
 
   @property
   def value_type(self):
@@ -806,7 +786,8 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
         lambda a, b: a.most_specific_compatible_type(b), self._element_spec,
         other._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
-                                   self._strategy)
+                                   self._strategy,
+                                   self._enable_get_next_as_optional)
 
   @property
   def _component_specs(self):
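This hunk merges two iterator specs and carries the new flag through unchanged. A sketch of what the per-component `most_specific_compatible_type` merge does for plain `TensorSpec`s (public API of this era, not code from this diff):

    import tensorflow as tf

    a = tf.TensorSpec([3, 2], tf.float32)
    b = tf.TensorSpec([4, 2], tf.float32)
    # The most specific type compatible with both relaxes the differing dim:
    print(a.most_specific_compatible_type(b))
    # TensorSpec(shape=(None, 2), dtype=tf.float32, name=None)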
@@ -825,32 +806,41 @@ class DistributedIteratorSpec(type_spec.TypeSpec):
     return value._iterators  # pylint: disable=protected-access
 
   def _from_components(self, components):
-    return DistributedIterator(input_workers=self._input_workers,
-                               iterators=None,
-                               components=components,
-                               element_spec=self._element_spec,
-                               strategy=self._strategy)
+    return DistributedIterator(
+        input_workers=self._input_workers,
+        iterators=None,
+        components=components,
+        element_spec=self._element_spec,
+        strategy=self._strategy,
+        enable_get_next_as_optional=self._enable_get_next_as_optional)
 
   @staticmethod
   def from_value(value):
     # pylint: disable=protected-access
     return DistributedIteratorSpec(value._input_workers, value._element_spec,
-                                   value._strategy)
+                                   value._strategy,
+                                   value._enable_get_next_as_optional)
 
   def _with_tensor_ranks_only(self):
     element_spec = nest.map_structure(
         lambda s: s._with_tensor_ranks_only(),  # pylint: disable=protected-access
         self._element_spec)
     return DistributedIteratorSpec(self._input_workers, element_spec,
-                                   self._strategy)
+                                   self._strategy,
+                                   self._enable_get_next_as_optional)
 
 
 class DistributedIterator(DistributedIteratorBase,
                           composite_tensor.CompositeTensor):
   """Input Iterator for a distributed dataset."""
 
-  def __init__(self, input_workers=None, iterators=None, strategy=None,
-               components=None, element_spec=None):
+  def __init__(self,
+               input_workers=None,
+               iterators=None,
+               strategy=None,
+               components=None,
+               element_spec=None,
+               enable_get_next_as_optional=False):
     if input_workers is None:
       raise ValueError("`input_workers` should be "
                        "provided.")
@@ -865,20 +855,15 @@ class DistributedIterator(DistributedIteratorBase,
       self._element_spec = element_spec
       self._input_workers = input_workers
       self._iterators = components
-      static_shape = _get_static_shape(self._iterators)
       self._strategy = strategy
-      if getattr(strategy.extended,
-                 "experimental_enable_get_next_as_optional", False):
-        self._enable_get_next_as_optional = (
-            not static_shape) or strategy.extended._in_multi_worker_mode()
-      else:
-        self._enable_get_next_as_optional = False
+      self._enable_get_next_as_optional = enable_get_next_as_optional
     else:
       if (components is not None and element_spec is not None):
         raise ValueError(error_message)
 
-      super(DistributedIterator, self).__init__(input_workers, iterators,
-                                                strategy)
+      super(DistributedIterator,
+            self).__init__(input_workers, iterators, strategy,
+                           enable_get_next_as_optional)
 
   @property
   def element_spec(self):
@@ -886,9 +871,9 @@ class DistributedIterator(DistributedIteratorBase,
 
   @property
   def _type_spec(self):
-    return DistributedIteratorSpec(self._input_workers,
-                                   self.element_spec,
-                                   self._strategy)
+    return DistributedIteratorSpec(self._input_workers, self.element_spec,
+                                   self._strategy,
+                                   self._enable_get_next_as_optional)
 
 
 class _IterableInput(DistributedDatasetInterface):
@@ -1001,6 +986,8 @@ class DistributedDataset(_IterableInput):
 
     self._input_workers = input_workers
     self._strategy = strategy
+    self._enable_get_next_as_optional = _enable_get_next_as_optional(
+        self._strategy, dataset.element_spec)
     self._element_spec = _create_distributed_tensor_spec(self._strategy,
                                                          dataset.element_spec)  # pylint: disable=protected-access
 
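This is the move the commit title describes: the decision is now computed once, at `DistributedDataset` construction time, from the strategy and `dataset.element_spec`, and is merely forwarded to iterators. A truth-table paraphrase of the decision, where the parameter names are hypothetical stand-ins for the internals named in this diff:

    def enable_partial_batch_handling(flag_enabled, static_shape, multi_worker):
      # flag_enabled mirrors strategy.extended.experimental_enable_get_next_as_optional;
      # multi_worker mirrors strategy.extended._in_multi_worker_mode().
      if not flag_enabled:
        return False
      return (not static_shape) or multi_worker

    assert not enable_partial_batch_handling(False, False, True)   # flag off wins
    assert not enable_partial_batch_handling(True, True, False)    # static, single worker
    assert enable_partial_batch_handling(True, False, False)       # non-static shapes
    assert enable_partial_batch_handling(True, True, True)         # multi-worker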
@@ -1019,11 +1006,17 @@ class DistributedDataset(_IterableInput):
         self._input_workers,
         enable_legacy_iterators)
     if enable_legacy_iterators:
-      iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
-                                       self._strategy)
+      iterator = DistributedIteratorV1(
+          self._input_workers,
+          worker_iterators,
+          self._strategy,
+          enable_get_next_as_optional=self._enable_get_next_as_optional)
     else:
-      iterator = DistributedIterator(self._input_workers, worker_iterators,
-                                     self._strategy)
+      iterator = DistributedIterator(
+          self._input_workers,
+          worker_iterators,
+          self._strategy,
+          enable_get_next_as_optional=self._enable_get_next_as_optional)
     iterator._element_spec = self.element_spec  # pylint: disable=protected-access
 
     # When async eager is enabled, sometimes the iterator may not finish
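User-visible path served by this hunk: `__iter__` now threads the dataset-level flag into whichever iterator class it builds, instead of each iterator re-deriving it. End-to-end sketch with public APIs only (assumes an eager TF 2.x environment):

    import tensorflow as tf

    strategy = tf.distribute.MirroredStrategy()
    dataset = tf.data.Dataset.range(10).batch(4)  # last batch is partial (size 2)
    dist_dataset = strategy.experimental_distribute_dataset(dataset)
    for batch in dist_dataset:  # iter() constructs a DistributedIterator
      print(batch)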
@@ -1104,7 +1097,8 @@ class DistributedDatasetV1(DistributedDataset):
         self._input_workers,
         True)
     iterator = DistributedIteratorV1(self._input_workers, worker_iterators,
-                                     self._strategy)
+                                     self._strategy,
+                                     self._enable_get_next_as_optional)
     iterator._element_spec = self.element_spec  # pylint: disable=protected-access
 
     # When async eager is enabled, sometimes the iterator may not finish
@@ -1156,6 +1150,8 @@ class DistributedDatasetsFromFunction(_IterableInput):
         _create_datasets_per_worker_with_input_context(self._input_contexts,
                                                        self._input_workers,
                                                        dataset_fn))
+    self._enable_get_next_as_optional = _enable_get_next_as_optional(
+        self._strategy, element_spec)
     self._element_spec = _create_distributed_tensor_spec(
         self._strategy, element_spec)
 
@@ -1174,11 +1170,17 @@ class DistributedDatasetsFromFunction(_IterableInput):
                                              enable_legacy_iterators)
 
     if enable_legacy_iterators:
-      iterator = DistributedIteratorV1(self._input_workers, iterators,
-                                       self._strategy)
+      iterator = DistributedIteratorV1(
+          self._input_workers,
+          iterators,
+          self._strategy,
+          enable_get_next_as_optional=self._enable_get_next_as_optional)
     else:
-      iterator = DistributedIterator(self._input_workers, iterators,
-                                     self._strategy)
+      iterator = DistributedIterator(
+          self._input_workers,
+          iterators,
+          self._strategy,
+          enable_get_next_as_optional=self._enable_get_next_as_optional)
     iterator._element_spec = self._element_spec  # pylint: disable=protected-access
 
     # When async eager is enabled, sometimes the iterator may not finish
@@ -1225,7 +1227,8 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction):
     iterators = _create_iterators_per_worker(self._datasets,
                                              self._input_workers, True)
     iterator = DistributedIteratorV1(self._input_workers, iterators,
-                                     self._strategy)
+                                     self._strategy,
+                                     self._enable_get_next_as_optional)
     iterator._element_spec = self._element_spec  # pylint: disable=protected-access
 
     # When async eager is enabled, sometimes the iterator may not finish
@@ -1289,9 +1292,8 @@ class InputFunctionIterator(DistributedIteratorV1):
           "input_fn must return a tf.data.Dataset or a callable.")
       iterators.append(iterator)
 
-    super(InputFunctionIterator, self).__init__(input_workers, iterators,
-                                                strategy)
-    self._enable_get_next_as_optional = False
+    super(InputFunctionIterator, self).__init__(
+        input_workers, iterators, strategy, enable_get_next_as_optional=False)
 
 
 # TODO(anjalisridhar): This class will soon be removed and users should move
@@ -1330,10 +1332,9 @@ class DatasetIterator(DistributedIteratorV1):
         input_context=input_context)
     worker_iterators = _create_iterators_per_worker(
         dist_dataset._cloned_datasets, input_workers, True)  # pylint: disable=protected-access
-    super(DatasetIterator, self).__init__(
-        input_workers,
-        worker_iterators,  # pylint: disable=protected-access
-        strategy)
+    super(DatasetIterator,
+          self).__init__(input_workers, worker_iterators, strategy,
+                         dist_dataset._enable_get_next_as_optional)  # pylint: disable=protected-access
     self._element_spec = dist_dataset.element_spec
 
 
@@ -1952,3 +1953,19 @@ def _replace_per_replica_spec(spec, i):
     return spec._value_specs[i]  # pylint: disable=protected-access
   else:
     return spec
+
+
+def _enable_get_next_as_optional(strategy, element_spec):
+  """Returns whether to enable using partial batch handling."""
+  # TODO(b/133073708): we currently need a flag to control the usage because
+  # there is a performance difference between get_next() and
+  # get_next_as_optional(). And we only enable get_next_as_optional when the
+  # output shapes are not static.
+  #
+  # TODO(rxsang): We want to always enable the get_next_as_optional behavior
+  # when user passed input_fn instead of dataset.
+  if not getattr(strategy.extended, "experimental_enable_get_next_as_optional",
+                 False):
+    return False
+  return not _is_statically_shaped(
+      element_spec) or strategy.extended._in_multi_worker_mode()  # pylint: disable=protected-access
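For context on the performance trade-off the TODO mentions: `get_next_as_optional()` lets a loop detect end-of-input, including a missing partial batch on some workers, without catching `OutOfRangeError`, at some extra per-step cost. Sketch with the public single-host iterator API (not taken from this diff):

    import tensorflow as tf

    iterator = iter(tf.data.Dataset.range(5).batch(2))
    while True:
      optional = iterator.get_next_as_optional()
      if not optional.has_value():
        break  # covers both the normal end and a missing partial batch
      print(optional.get_value())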