Add CompositeTensor support to distributed iterators.
PiperOrigin-RevId: 308423654 Change-Id: I406d7adeca6059112768f0b8a02a80fbbd463463
This commit is contained in:
		
							parent
							
								
									be3e3961f8
								
							
						
					
					
						commit
						dac6d6ae7c
					
				| @ -847,6 +847,38 @@ distribute_py_test( | ||||
|     srcs = ["input_lib_test.py"], | ||||
|     main = "input_lib_test.py", | ||||
|     shard_count = 10, | ||||
|     tags = [ | ||||
|         "multi_and_single_gpu", | ||||
|         "no_gpu_presubmit",  # TODO(b/154660040) | ||||
|     ], | ||||
|     deps = [ | ||||
|         ":collective_all_reduce_strategy", | ||||
|         ":combinations", | ||||
|         ":input_lib", | ||||
|         ":mirrored_strategy", | ||||
|         ":multi_worker_test_base", | ||||
|         ":reduce_util", | ||||
|         ":strategy_combinations", | ||||
|         ":values", | ||||
|         "//tensorflow/python:control_flow_ops", | ||||
|         "//tensorflow/python:errors", | ||||
|         "//tensorflow/python:math_ops", | ||||
|         "//tensorflow/python:sparse_ops", | ||||
|         "//tensorflow/python:sparse_tensor", | ||||
|         "//tensorflow/python/data/ops:dataset_ops", | ||||
|         "//tensorflow/python/eager:context", | ||||
|         "//tensorflow/python/eager:test", | ||||
|         "//tensorflow/python/ops/ragged:ragged_tensor", | ||||
|         "//third_party/py/numpy", | ||||
|         "@absl_py//absl/testing:parameterized", | ||||
|     ], | ||||
| ) | ||||
| 
 | ||||
| distribute_py_test( | ||||
|     name = "input_lib_type_spec_test", | ||||
|     srcs = ["input_lib_type_spec_test.py"], | ||||
|     main = "input_lib_type_spec_test.py", | ||||
|     shard_count = 10, | ||||
|     tags = [ | ||||
|         "multi_and_single_gpu", | ||||
|     ], | ||||
| @ -1453,9 +1485,10 @@ distribute_py_test( | ||||
|     name = "ctl_correctness_test", | ||||
|     srcs = ["ctl_correctness_test.py"], | ||||
|     main = "ctl_correctness_test.py", | ||||
|     shard_count = 10, | ||||
|     shard_count = 20, | ||||
|     tags = [ | ||||
|         "multi_and_single_gpu", | ||||
|         "no_gpu_presubmit",  # TODO(b/154660040) | ||||
|         "noguitar",  # b/140755528 | ||||
|     ], | ||||
|     deps = [ | ||||
|  | ||||
| @ -330,8 +330,9 @@ class CollectiveAllReduceExtended(mirrored_strategy.MirroredExtended): | ||||
|         communication=self._communication) | ||||
|     super(CollectiveAllReduceExtended, self)._initialize_single_worker( | ||||
|         local_devices) | ||||
|     host_device = device_util.get_host_for_device(self._worker_device) | ||||
|     self._input_workers = input_lib.InputWorkers( | ||||
|         [(self._worker_device, self.worker_devices)]) | ||||
|         [(host_device, self.worker_devices)]) | ||||
| 
 | ||||
|     # Add a default device so that ops without specified devices will not end up | ||||
|     # on other workers. | ||||
|  | ||||
| @ -33,6 +33,7 @@ from tensorflow.python.distribute import input_ops | ||||
| from tensorflow.python.distribute import reduce_util | ||||
| from tensorflow.python.distribute import values | ||||
| from tensorflow.python.eager import context | ||||
| from tensorflow.python.framework import composite_tensor | ||||
| from tensorflow.python.framework import constant_op | ||||
| from tensorflow.python.framework import device as tf_device | ||||
| from tensorflow.python.framework import dtypes | ||||
| @ -41,6 +42,7 @@ from tensorflow.python.framework import ops | ||||
| from tensorflow.python.framework import sparse_tensor | ||||
| from tensorflow.python.framework import tensor_shape | ||||
| from tensorflow.python.framework import tensor_util | ||||
| from tensorflow.python.framework import type_spec | ||||
| from tensorflow.python.ops import array_ops | ||||
| from tensorflow.python.ops import control_flow_ops | ||||
| from tensorflow.python.ops import math_ops | ||||
| @ -143,9 +145,10 @@ class InputWorkers(object): | ||||
|       worker_device_pairs: A sequence of pairs: | ||||
|         `(input device, a tuple of compute devices fed by that input device)`. | ||||
|     """ | ||||
|     self._input_worker_devices = tuple(d for d, _ in worker_device_pairs) | ||||
|     self._worker_device_pairs = worker_device_pairs | ||||
|     self._input_worker_devices = tuple(d for d, _ in self._worker_device_pairs) | ||||
|     self._fed_devices = tuple(tuple(device_util.canonicalize(d) for d in f) | ||||
|                               for _, f in worker_device_pairs) | ||||
|                               for _, f in self._worker_device_pairs) | ||||
| 
 | ||||
|   @property | ||||
|   def num_workers(self): | ||||
| @ -165,6 +168,12 @@ class InputWorkers(object): | ||||
|                             for i in range(len(devices))) | ||||
|     return "%s:{\n%s}" % (self.__class__.__name__, debug_repr) | ||||
| 
 | ||||
|   def serialize(self): | ||||
|     """Returns the `worker_device_pairs` sequence given to the constructor.""" | ||||
|     return self._worker_device_pairs | ||||
| 
 | ||||
|   def deserialize(self, worker_device_pairs): | ||||
|     """Builds a new `InputWorkers` from `worker_device_pairs`. | ||||
|  | ||||
|     The receiver's own state is neither read nor modified. | ||||
|     """ | ||||
|     return InputWorkers(worker_device_pairs) | ||||
| 
 | ||||
| 
 | ||||
| def _get_next_as_optional(iterator, strategy, name=None): | ||||
|   """Returns an empty dataset indicator and the next input from the iterator.""" | ||||
| @ -208,7 +217,7 @@ def _get_next_as_optional(iterator, strategy, name=None): | ||||
| 
 | ||||
| 
 | ||||
| def _is_statically_shaped(tensor_class, shape): | ||||
|   """Test if an iteratort output is statically shaped. | ||||
|   """Test if an iterator output is statically shaped. | ||||
| 
 | ||||
|   For sparse and ragged tensors this only tests the batch dimension. | ||||
| 
 | ||||
| @ -231,20 +240,27 @@ def _is_statically_shaped(tensor_class, shape): | ||||
|   return shape.is_fully_defined() | ||||
| 
 | ||||
| 
 | ||||
| class DistributedIterator(object): | ||||
| def _get_static_shape(iterators): | ||||
|   """Returns a boolean indicating if the input is fully defined.""" | ||||
|   static_shape = True | ||||
|   for iterator in iterators: | ||||
|     if not isinstance(iterator, (_SingleWorkerOwnedDatasetIterator, | ||||
|                                  _SingleWorkerDatasetIterator)): | ||||
|       continue | ||||
|     flattened = zip(nest.flatten(iterator.output_shapes), | ||||
|                     nest.flatten(iterator.output_classes)) | ||||
|     for output_shape, output_class in flattened: | ||||
|       if not _is_statically_shaped(output_class, output_shape): | ||||
|         static_shape = False | ||||
|         break | ||||
|     return static_shape | ||||
| 
 | ||||
| 
 | ||||
| class DistributedIteratorBase(object): | ||||
|   """Common implementation for all input iterators.""" | ||||
| 
 | ||||
|   def __init__(self, input_workers, iterators, strategy): | ||||
|     static_shape = True | ||||
|     for iterator in iterators: | ||||
|       if not isinstance(iterator, _SingleWorkerDatasetIterator): | ||||
|         continue | ||||
|       flattened = zip(nest.flatten(iterator.output_shapes), | ||||
|                       nest.flatten(iterator.output_classes)) | ||||
|       for output_shape, output_class in flattened: | ||||
|         if not _is_statically_shaped(output_class, output_shape): | ||||
|           static_shape = False | ||||
|           break | ||||
|     static_shape = _get_static_shape(iterators) | ||||
| 
 | ||||
|     # TODO(b/133073708): we currently need a flag to control the usage because | ||||
|     # there is a performance difference between get_next() and | ||||
| @ -360,6 +376,10 @@ class DistributedIterator(object): | ||||
| 
 | ||||
|     return values.regroup(replicas) | ||||
| 
 | ||||
| 
 | ||||
| class DistributedIteratorV1(DistributedIteratorBase): | ||||
|   """Input Iterator for a distributed dataset.""" | ||||
| 
 | ||||
|   # We need a private initializer method for re-initializing multidevice | ||||
|   # iterators when used with Keras training loops. If we don't reinitialize the | ||||
|   # iterator we run into memory leak issues (b/123315763). | ||||
| @ -370,23 +390,14 @@ class DistributedIterator(object): | ||||
|       init_ops.extend(it.initialize()) | ||||
|     return control_flow_ops.group(init_ops) | ||||
| 
 | ||||
|   @property | ||||
|   def element_spec(self): | ||||
|     """The type specification of an element of this iterator.""" | ||||
|     return self._element_spec | ||||
| 
 | ||||
| 
 | ||||
| class DistributedIteratorV1(DistributedIterator): | ||||
|   """Input Iterator for a distributed dataset instance.""" | ||||
| 
 | ||||
|   @deprecated(None, "Use the iterator's `initializer` property instead.") | ||||
|   def initialize(self): | ||||
|     """Initialze underlying iterators. | ||||
|     """Initialize underlying iterators. | ||||
| 
 | ||||
|     Returns: | ||||
|       A list of any initializer ops that should be run. | ||||
|     """ | ||||
|     return super(DistributedIteratorV1, self)._initializer | ||||
|     return self._initializer | ||||
| 
 | ||||
|   @property | ||||
|   def initializer(self): | ||||
| @ -415,6 +426,161 @@ class DistributedIteratorV1(DistributedIterator): | ||||
|         return self._iterators[i] | ||||
|     return None | ||||
| 
 | ||||
|   @property | ||||
|   def element_spec(self): | ||||
|     """The type specification of an element of this iterator.""" | ||||
|     return self._element_spec | ||||
| 
 | ||||
| 
 | ||||
class DistributedIteratorSpec(type_spec.TypeSpec):
  """Type specification for `DistributedIterator`."""

  __slots__ = ["_input_workers", "_element_spec", "_strategy"]

  def __init__(self, input_workers, element_spec, strategy):
    """Stores the pieces needed to rebuild a `DistributedIterator`.

    Args:
      input_workers: The `InputWorkers` object of the iterator. A `tuple` is
        rejected (see below).
      element_spec: The type specification of the iterator's elements.
      strategy: The distribution strategy the iterator belongs to.

    Raises:
      NotImplementedError: If `input_workers` is a tuple.
    """
    # Deserialization of this class is intentionally disallowed because the
    # strategy object is never serialized. Currently `_deserialize` is only
    # reached when saving/restoring with SavedModels.
    if isinstance(input_workers, tuple):
      raise NotImplementedError("DistributedIteratorSpec does not have support "
                                "for deserialization.")
    self._input_workers = input_workers
    self._element_spec = element_spec
    self._strategy = strategy

  @property
  def value_type(self):
    return DistributedIterator

  def _serialize(self):
    # The strategy object itself cannot be serialized, so its id() stands in
    # for it; that is enough for comparisons.
    serialized_workers = self._input_workers.serialize()
    strategy_id = id(self._strategy)
    return (serialized_workers, self._element_spec, strategy_id)

  def _deserialize(self):
    raise ValueError("Deserialization is currently unsupported for "
                     "DistributedIteratorSpec.")

  @staticmethod
  def _is_compatible(a, b):
    """Returns true if the given type serializations are compatible."""
    recurse = DistributedIteratorSpec._is_compatible
    if type(a) is not type(b):
      return False
    if isinstance(a, tuple):
      if len(a) != len(b):
        return False
      return all(recurse(x, y) for (x, y) in zip(a, b))
    if isinstance(a, dict):
      if len(a) != len(b):
        return False
      if sorted(a.keys()) != sorted(b.keys()):
        return False
      return all(recurse(a[k], b[k]) for k in a.keys())
    if isinstance(a, (type_spec.TypeSpec, tensor_shape.TensorShape,
                      dtypes.DType)):
      return a.is_compatible_with(b)
    return a == b

  # Overridden so that the merged spec object can be reconstructed here.
  def most_specific_compatible_type(self, other):
    """Returns the most specific TypeSpec compatible with `self` and `other`.

    Args:
      other: A `TypeSpec`.

    Raises:
      ValueError: If there is no TypeSpec that is compatible with both `self`
        and `other`.
    """
    # pylint: disable=protected-access
    if type(self) is not type(other):
      raise ValueError("No TypeSpec is compatible with both %s and %s" %
                       (self, other))

    own_workers = self._input_workers.serialize()
    other_workers = other._input_workers.serialize()
    if not self._is_compatible(own_workers, other_workers):
      raise ValueError("_input_workers is not compatible with both %s "
                       "and %s" % (self, other))

    if self._element_spec != other._element_spec:
      raise ValueError("_element_spec is not compatible with both %s "
                       "and %s" % (self, other))

    # Both specs must refer to the very same strategy object.
    if self._strategy is not other._strategy:
      raise ValueError("tf.distribute strategy is not compatible with both %s "
                       "and %s" % (self, other))

    return DistributedIteratorSpec(self._input_workers, self._element_spec,
                                   self._strategy)

  @property
  def _component_specs(self):
    # One single-worker iterator spec per (input device, compute devices) pair.
    worker_device_pairs = self._input_workers._worker_device_pairs  # pylint: disable=protected-access
    return [
        _SingleWorkerDatasetIteratorSpec(input_device,
                                         compute_devices,
                                         element_spec=self._element_spec)
        for input_device, compute_devices in worker_device_pairs
    ]

  def _to_components(self, value):
    return value._iterators  # pylint: disable=protected-access

  def _from_components(self, components):
    return DistributedIterator(input_workers=self._input_workers,
                               iterators=None,
                               components=components,
                               element_spec=self._element_spec,
                               strategy=self._strategy)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return DistributedIteratorSpec(value._input_workers, value._element_spec,
                                   value._strategy)
| 
 | ||||
| 
 | ||||
class DistributedIterator(DistributedIteratorBase,
                          composite_tensor.CompositeTensor):
  """Input Iterator for a distributed dataset."""

  def __init__(self, input_workers=None, iterators=None, strategy=None,
               components=None, element_spec=None):
    """Creates the iterator from `iterators` or from `components`.

    Args:
      input_workers: The `InputWorkers` object; always required.
      iterators: Per-worker iterators. When given, construction is delegated
        to the base class initializer.
      strategy: The distribution strategy. Its `extended` attribute is read
        when building from `components`.
      components: Used as the per-worker iterators when `iterators` is None;
        must then be accompanied by `element_spec`.
      element_spec: Type specification of the elements; required together
        with `components`.

    Raises:
      ValueError: If `input_workers` is None, if `iterators` is None and
        either `components` or `element_spec` is missing, or if `iterators`
        is given together with both `components` and `element_spec`.
    """
    if input_workers is None:
      raise ValueError("`input_workers` should be "
                       "provided.")

    # `input_workers` is mandatory (checked above); the actual alternative to
    # `components` + `element_spec` is `iterators`, so the message names that.
    error_message = ("Either `iterators` or "
                     "both `components` and `element_spec` need to be "
                     "provided.")

    if iterators is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      # Rebuild path: populate the attributes directly instead of calling the
      # base class initializer.
      self._element_spec = element_spec
      self._input_workers = input_workers
      self._iterators = components
      static_shape = _get_static_shape(self._iterators)
      self._strategy = strategy
      # get_next_as_optional is only used when the strategy opts in and the
      # input is not statically shaped.
      if getattr(
          strategy.extended, "experimental_enable_get_next_as_optional", False):
        self._enable_get_next_as_optional = not static_shape
      else:
        self._enable_get_next_as_optional = False
    else:
      if (components is not None and element_spec is not None):
        raise ValueError(error_message)

      super(DistributedIterator, self).__init__(input_workers, iterators,
                                                strategy)

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec

  @property
  def _type_spec(self):
    """A `DistributedIteratorSpec` describing this iterator."""
    return DistributedIteratorSpec(self._input_workers,
                                   self.element_spec,
                                   self._strategy)
| 
 | ||||
| 
 | ||||
| class _IterableInput(object): | ||||
|   """Base class for iterable inputs for distribution strategies.""" | ||||
| @ -482,7 +648,6 @@ class DistributedDataset(_IterableInput): | ||||
|         `num_input_pipelines` in the `InputContext`. | ||||
|     """ | ||||
|     super(DistributedDataset, self).__init__(input_workers=input_workers) | ||||
| 
 | ||||
|     # We clone and shard the dataset on each worker. The current setup tries to | ||||
|     # shard the dataset by files if possible so that each worker sees a | ||||
|     # different subset of files. If that is not possible, will attempt to shard | ||||
| @ -541,10 +706,20 @@ class DistributedDataset(_IterableInput): | ||||
|       raise RuntimeError("__iter__() is only supported inside of tf.function " | ||||
|                          "or when eager execution is enabled.") | ||||
| 
 | ||||
|     # This is an optional flag that can be used to turn off using | ||||
|     # OwnedMultiDeviceIterators and instead use the legacy MultiDeviceIterators | ||||
|     # as a stop gap solution that will allow us to roll out this change. | ||||
|     enable_legacy_iterators = getattr(self._strategy, | ||||
|                                       "_enable_legacy_iterators", False) | ||||
|     worker_iterators = _create_iterators_per_worker(self._cloned_datasets, | ||||
|                                                     self._input_workers) | ||||
|     iterator = DistributedIterator(self._input_workers, worker_iterators, | ||||
|                                    self._strategy) | ||||
|                                                     self._input_workers, | ||||
|                                                     enable_legacy_iterators) | ||||
|     if enable_legacy_iterators: | ||||
|       iterator = DistributedIteratorV1(self._input_workers, worker_iterators, | ||||
|                                        self._strategy) | ||||
|     else: | ||||
|       iterator = DistributedIterator(self._input_workers, worker_iterators, | ||||
|                                      self._strategy) | ||||
|     iterator._element_spec = self.element_spec  # pylint: disable=protected-access | ||||
|     return iterator | ||||
| 
 | ||||
| @ -615,12 +790,21 @@ class DistributedDatasetV1(DistributedDataset): | ||||
| 
 | ||||
|   def _get_iterator(self): | ||||
|     worker_iterators = _create_iterators_per_worker(self._cloned_datasets, | ||||
|                                                     self._input_workers) | ||||
|                                                     self._input_workers, | ||||
|                                                     True) | ||||
|     iterator = DistributedIteratorV1(self._input_workers, worker_iterators, | ||||
|                                      self._strategy) | ||||
|     iterator._element_spec = self.element_spec  # pylint: disable=protected-access | ||||
|     return iterator | ||||
| 
 | ||||
|   def __iter__(self): | ||||
|     if (ops.executing_eagerly_outside_functions() or | ||||
|         ops.get_default_graph().building_function): | ||||
|       return self._get_iterator() | ||||
| 
 | ||||
|     raise RuntimeError("__iter__() is only supported inside of tf.function " | ||||
|                        "or when eager execution is enabled.") | ||||
| 
 | ||||
| 
 | ||||
| # TODO(priyag): Add other replication modes. | ||||
| class DistributedDatasetsFromFunction(_IterableInput): | ||||
| @ -653,20 +837,36 @@ class DistributedDatasetsFromFunction(_IterableInput): | ||||
|     self._strategy = strategy | ||||
|     self._element_spec = None | ||||
| 
 | ||||
|   def __iter__(self): | ||||
|     if not (context.executing_eagerly() or | ||||
|             ops.get_default_graph().building_function): | ||||
|       raise RuntimeError("__iter__() is only supported inside of tf.function " | ||||
|                          "or when eager execution is enabled.") | ||||
|     super(DistributedDatasetsFromFunction, self).__init__( | ||||
|         input_workers=input_workers) | ||||
| 
 | ||||
|     iterators, element_spec = _create_iterators_per_worker_with_input_context( | ||||
|         self._input_contexts, self._input_workers, self._dataset_fn) | ||||
|     iterator = DistributedIterator(self._input_workers, iterators, | ||||
|                                    self._strategy) | ||||
|     self._element_spec = _create_distributed_tensor_spec(self._strategy, | ||||
|                                                          element_spec) | ||||
|     iterator._element_spec = self._element_spec  # pylint: disable=protected-access | ||||
|     return iterator | ||||
|   def __iter__(self): | ||||
|     if (ops.executing_eagerly_outside_functions() or | ||||
|         ops.get_default_graph().building_function): | ||||
|       # This is an optional flag that can be used to turn off using | ||||
|       # OwnedMultiDeviceIterators and instead use the legacy | ||||
|       # MultiDeviceIterators as a stop gap solution that will allow us to roll | ||||
|       # out this change. | ||||
|       enable_legacy_iterators = getattr(self._strategy, | ||||
|                                         "_enable_legacy_iterators", False) | ||||
| 
 | ||||
|       iterators, element_spec = _create_iterators_per_worker_with_input_context( | ||||
|           self._input_contexts, self._input_workers, self._dataset_fn, | ||||
|           enable_legacy_iterators) | ||||
| 
 | ||||
|       if enable_legacy_iterators: | ||||
|         iterator = DistributedIteratorV1(self._input_workers, iterators, | ||||
|                                          self._strategy) | ||||
|       else: | ||||
|         iterator = DistributedIterator(self._input_workers, iterators, | ||||
|                                        self._strategy) | ||||
|       self._element_spec = _create_distributed_tensor_spec(self._strategy, | ||||
|                                                            element_spec) | ||||
|       iterator._element_spec = self._element_spec  # pylint: disable=protected-access | ||||
|       return iterator | ||||
| 
 | ||||
|     raise RuntimeError("__iter__() is only supported inside of tf.function " | ||||
|                        "or when eager execution is enabled.") | ||||
| 
 | ||||
|   @property | ||||
|   def element_spec(self): | ||||
| @ -705,7 +905,8 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction): | ||||
| 
 | ||||
|   def _get_iterator(self): | ||||
|     iterators, element_spec = _create_iterators_per_worker_with_input_context( | ||||
|         self._input_contexts, self._input_workers, self._dataset_fn) | ||||
|         self._input_contexts, self._input_workers, self._dataset_fn, | ||||
|         True) | ||||
|     iterator = DistributedIteratorV1(self._input_workers, iterators, | ||||
|                                      self._strategy) | ||||
|     self._element_spec = _create_distributed_tensor_spec(self._strategy, | ||||
| @ -713,6 +914,14 @@ class DistributedDatasetsFromFunctionV1(DistributedDatasetsFromFunction): | ||||
|     iterator._element_spec = self._element_spec  # pylint: disable=protected-access | ||||
|     return iterator | ||||
| 
 | ||||
|   def __iter__(self): | ||||
|     if (ops.executing_eagerly_outside_functions() or | ||||
|         ops.get_default_graph().building_function): | ||||
|       return self._get_iterator() | ||||
| 
 | ||||
|     raise RuntimeError("__iter__() is only supported inside of tf.function " | ||||
|                        "or when eager execution is enabled.") | ||||
| 
 | ||||
| 
 | ||||
| # TODO(anjalisridhar): This class will be soon be removed in favor of newer | ||||
| # APIs. | ||||
| @ -797,7 +1006,7 @@ class DatasetIterator(DistributedIteratorV1): | ||||
|         split_batch_by=split_batch_by, | ||||
|         input_context=input_context) | ||||
|     worker_iterators = _create_iterators_per_worker( | ||||
|         dist_dataset._cloned_datasets, input_workers)  # pylint: disable=protected-access | ||||
|         dist_dataset._cloned_datasets, input_workers, True)  # pylint: disable=protected-access | ||||
|     super(DatasetIterator, self).__init__( | ||||
|         input_workers, | ||||
|         worker_iterators,  # pylint: disable=protected-access | ||||
| @ -808,18 +1017,18 @@ class DatasetIterator(DistributedIteratorV1): | ||||
| def _dummy_tensor_fn(value_structure): | ||||
|   """A function to create dummy tensors from `value_structure`.""" | ||||
| 
 | ||||
|   def create_dummy_tensor(type_spec): | ||||
|   def create_dummy_tensor(spec): | ||||
|     """Create a dummy tensor with possible batch dimensions set to 0.""" | ||||
|     if isinstance(type_spec, ragged_tensor.RaggedTensorSpec): | ||||
|     if isinstance(spec, ragged_tensor.RaggedTensorSpec): | ||||
|       # Splice out the ragged dimensions. | ||||
|       # pylint: disable=protected-access | ||||
|       feature_shape = type_spec._shape[:1].concatenate( | ||||
|           type_spec._shape[(1 + type_spec._ragged_rank):]) | ||||
|       feature_type = type_spec._dtype | ||||
|       feature_shape = spec._shape[:1].concatenate( | ||||
|           spec._shape[(1 + spec._ragged_rank):]) | ||||
|       feature_type = spec._dtype | ||||
|       # pylint: enable=protected-access | ||||
|     else: | ||||
|       feature_shape = type_spec.shape | ||||
|       feature_type = type_spec.dtype | ||||
|       feature_shape = spec.shape | ||||
|       feature_type = spec.dtype | ||||
|     # Ideally we should set the batch dimension to 0, however as in | ||||
|     # DistributionStrategy we don't know the batch dimension, we try to | ||||
|     # guess it as much as possible. If the feature has unknown dimensions, we | ||||
| @ -827,11 +1036,11 @@ def _dummy_tensor_fn(value_structure): | ||||
|     # first dimension as batch dimension and set it to 0. | ||||
|     dims = ([dim if dim is not None else 0 for dim in feature_shape.as_list()] | ||||
|             if feature_shape else []) | ||||
|     if dims and (isinstance(type_spec, ragged_tensor.RaggedTensorSpec) or | ||||
|     if dims and (isinstance(spec, ragged_tensor.RaggedTensorSpec) or | ||||
|                  feature_shape.is_fully_defined()): | ||||
|       dims[0] = tensor_shape.Dimension(0) | ||||
| 
 | ||||
|     if isinstance(type_spec, sparse_tensor.SparseTensorSpec): | ||||
|     if isinstance(spec, sparse_tensor.SparseTensorSpec): | ||||
|       return sparse_tensor.SparseTensor( | ||||
|           values=array_ops.zeros(0, feature_type), | ||||
|           indices=array_ops.zeros((0, len(dims)), dtypes.int64), | ||||
| @ -839,26 +1048,26 @@ def _dummy_tensor_fn(value_structure): | ||||
| 
 | ||||
|     # Create the dummy tensor. | ||||
|     dummy_tensor = array_ops.zeros(tensor_shape.TensorShape(dims), feature_type) | ||||
|     if isinstance(type_spec, ragged_tensor.RaggedTensorSpec): | ||||
|     if isinstance(spec, ragged_tensor.RaggedTensorSpec): | ||||
|       # Reinsert the ragged dimensions with size 0. | ||||
|       # pylint: disable=protected-access | ||||
|       row_splits = array_ops.zeros(1, type_spec._row_splits_dtype) | ||||
|       row_splits = array_ops.zeros(1, spec._row_splits_dtype) | ||||
|       dummy_tensor = ragged_tensor.RaggedTensor.from_nested_row_splits( | ||||
|           dummy_tensor, (row_splits,) * type_spec._ragged_rank, validate=False) | ||||
|           dummy_tensor, (row_splits,) * spec._ragged_rank, validate=False) | ||||
|       # pylint: enable=protected-access | ||||
|     return dummy_tensor | ||||
| 
 | ||||
|   return nest.map_structure(create_dummy_tensor, value_structure) | ||||
| 
 | ||||
| 
 | ||||
| class _SingleWorkerDatasetIterator(object): | ||||
| class _SingleWorkerDatasetIteratorBase(object): | ||||
|   """Iterator for a single `tf.data.Dataset`.""" | ||||
| 
 | ||||
|   def __init__(self, dataset, worker, devices): | ||||
|     """Create iterator for the `dataset` to fetch data to worker's `devices` . | ||||
| 
 | ||||
|     `MultiDeviceIterator` is used to prefetch input to the devices on the | ||||
|     given worker. | ||||
|     A `MultiDeviceIterator`  or `OwnedMultiDeviceIterator` is used to prefetch | ||||
|     input to the devices on the given worker. | ||||
| 
 | ||||
|     Args: | ||||
|       dataset: A `tf.data.Dataset` instance. | ||||
| @ -868,13 +1077,11 @@ class _SingleWorkerDatasetIterator(object): | ||||
|     self._dataset = dataset | ||||
|     self._worker = worker | ||||
|     self._devices = devices | ||||
|     self._element_spec = dataset.element_spec | ||||
|     self._make_iterator() | ||||
| 
 | ||||
|   def _make_iterator(self): | ||||
|     """Make appropriate iterator on the dataset.""" | ||||
|     with ops.device(self._worker): | ||||
|       self._iterator = multi_device_iterator_ops.MultiDeviceIterator( | ||||
|           self._dataset, self._devices) | ||||
|     raise NotImplementedError("must be implemented in descendants") | ||||
| 
 | ||||
|   def get_next(self, device, name=None): | ||||
|     """Get next element for the given device.""" | ||||
| @ -923,9 +1130,9 @@ class _SingleWorkerDatasetIterator(object): | ||||
|         # Place the condition op in the same device as the data so the data | ||||
|         # doesn't need to be sent back to the worker. | ||||
|         with ops.device(self._devices[i]): | ||||
|           # As MultiDeviceIterator will fetch data in order, so we only need to | ||||
|           # check if the first replica has value to see whether there is data | ||||
|           # left for this single worker. | ||||
|           # Data will be fetched in order, so we only need to check if the first | ||||
|           # replica has value to see whether there is data left for this single | ||||
|           # worker. | ||||
|           if i == 0: | ||||
|             worker_has_value = data.has_value() | ||||
| 
 | ||||
| @ -943,8 +1150,159 @@ class _SingleWorkerDatasetIterator(object): | ||||
| 
 | ||||
|       return worker_has_value, result | ||||
| 
 | ||||
| 
 | ||||
class _SingleWorkerDatasetIteratorSpec(type_spec.TypeSpec):
  """Type specification for `_SingleWorkerOwnedDatasetIterator`."""

  __slots__ = ["_worker", "_devices", "_element_spec"]

  def __init__(self, worker, devices, element_spec):
    """Records the worker, its devices and the element type specification."""
    self._worker = worker
    self._devices = devices
    self._element_spec = element_spec

  @property
  def value_type(self):
    return _SingleWorkerOwnedDatasetIterator

  def _serialize(self):
    # Devices are frozen into a tuple for the serialized form.
    device_tuple = tuple(self._devices)
    return (self._worker, device_tuple, self._element_spec)

  @property
  def _component_specs(self):
    # The only component is the underlying multi-device iterator.
    iterator_spec = multi_device_iterator_ops.MultiDeviceIteratorSpec(
        self._devices, self._worker, element_spec=self._element_spec)
    return [iterator_spec]

  def _to_components(self, value):
    return [value._iterator]  # pylint: disable=protected-access

  def _from_components(self, components):
    return _SingleWorkerOwnedDatasetIterator(
        dataset=None,
        worker=self._worker,
        devices=self._devices,
        components=components,
        element_spec=self._element_spec)

  @staticmethod
  def from_value(value):
    # pylint: disable=protected-access
    return _SingleWorkerDatasetIteratorSpec(value._worker, value._devices,
                                            value._element_spec)
| 
 | ||||
| 
 | ||||
class _SingleWorkerOwnedDatasetIterator(_SingleWorkerDatasetIteratorBase,
                                        composite_tensor.CompositeTensor):
  """Iterator for a DistributedDataset instance."""

  def __init__(self, dataset=None, worker=None, devices=None, components=None,
               element_spec=None):
    """Create iterator for the `dataset` to fetch data to worker's `devices`.

    `OwnedMultiDeviceIterator` is used to prefetch input to the devices on the
    given worker. The lifetime of this iterator is tied to the encompassing
    python object. Once we go out of scope of the python object or return from
    a tf.function the underlying iterator resource is deleted.

    Exactly one construction mode must be used: either pass `dataset`, or pass
    both `components` and `element_spec` (the mode used by
    `_SingleWorkerDatasetIteratorSpec._from_components`).

    Args:
      dataset: A `tf.data.Dataset` instance.
      worker: Worker on which ops should be created.
      devices: Distribute data from `dataset` to these devices.
      components: Tensor components to construct the
        _SingleWorkerOwnedDatasetIterator from.
      element_spec: A nested structure of `TypeSpec` objects that represents
        the type specification of elements of the iterator.

    Raises:
      ValueError: If `worker` or `devices` is missing, or if the arguments do
        not match exactly one of the two construction modes.
    """
    if worker is None or devices is None:
      raise ValueError("Both `worker` and `devices` should be provided")

    error_message = ("Either `dataset` or both `components` and `element_spec` "
                     "need to be provided.")

    if dataset is None:
      if (components is None or element_spec is None):
        raise ValueError(error_message)
      # Rebuilding from components: the base-class constructor is skipped and
      # the already-existing iterator (the first component) is adopted as-is.
      self._element_spec = element_spec
      self._worker = worker
      self._devices = devices
      self._iterator = components[0]
    else:
      if (components is not None or element_spec is not None):
        raise ValueError(error_message)
      super(_SingleWorkerOwnedDatasetIterator, self).__init__(dataset, worker,
                                                              devices)

  def _make_iterator(self):
    """Make appropriate iterator on the dataset.

    Raises:
      ValueError: If no worker device is set on this iterator.
    """
    if not self._worker:
      raise ValueError("Worker device must be specified when creating an "
                       "owned iterator.")
    host_device = device_util.get_host_for_device(self._worker)
    with ops.device(self._worker):
      self._iterator = multi_device_iterator_ops.OwnedMultiDeviceIterator(
          self._dataset, self._devices, source_device=host_device)

  @property
  def element_spec(self):
    """The type specification of an element of this iterator."""
    return self._element_spec

  @property
  def _type_spec(self):
    """The `_SingleWorkerDatasetIteratorSpec` describing this iterator."""
    return _SingleWorkerDatasetIteratorSpec(self._worker, self._devices,
                                            self._element_spec)

  @property
  def output_classes(self):
    """Returns the class of each component of an element of this iterator.

    The expected values are `tf.Tensor` and `tf.SparseTensor`.

    Returns:
      A nested structure of Python `type` objects corresponding to each
      component of an element of this iterator.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_classes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_shapes(self):
    """Returns the shape of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.TensorShape` objects corresponding to each
      component of an element of this iterator.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_shapes(),  # pylint: disable=protected-access
        self._element_spec)

  @property
  def output_types(self):
    """Returns the type of each component of an element of this iterator.

    Returns:
      A nested structure of `tf.DType` objects corresponding to each component
      of an element of this iterator.
    """
    return nest.map_structure(
        lambda component_spec: component_spec._to_legacy_output_types(),  # pylint: disable=protected-access
        self._element_spec)
| 
 | ||||
| 
 | ||||
| class _SingleWorkerDatasetIterator(_SingleWorkerDatasetIteratorBase): | ||||
|   """Iterator for a single DistributedDatasetV1 instance.""" | ||||
| 
 | ||||
|   def _make_iterator(self): | ||||
|     """Make appropriate iterator on the dataset.""" | ||||
|     with ops.device(self._worker): | ||||
|       self._iterator = multi_device_iterator_ops.MultiDeviceIterator( | ||||
|           self._dataset, self._devices) | ||||
| 
 | ||||
|   def initialize(self): | ||||
|     """Initialze underlying iterator. | ||||
|     """Initialize underlying iterator. | ||||
| 
 | ||||
|     In eager execution, this simply recreates the underlying iterator. | ||||
|     In graph execution, it returns the initializer ops for the underlying | ||||
| @ -1005,7 +1363,8 @@ class _SingleWorkerCallableIterator(object): | ||||
|     return [] | ||||
| 
 | ||||
| 
 | ||||
| def _create_iterators_per_worker(worker_datasets, input_workers): | ||||
| def _create_iterators_per_worker(worker_datasets, input_workers, | ||||
|                                  enable_legacy_iterators): | ||||
|   """Create a multidevice iterator on each of the workers.""" | ||||
|   assert isinstance(input_workers, InputWorkers) | ||||
| 
 | ||||
| @ -1014,23 +1373,35 @@ def _create_iterators_per_worker(worker_datasets, input_workers): | ||||
|   for i, worker in enumerate(input_workers.worker_devices): | ||||
|     with ops.device(worker): | ||||
|       worker_devices = input_workers.compute_devices_for_worker(i) | ||||
|       iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker, | ||||
|                                               worker_devices) | ||||
|       if tf2.enabled() and not enable_legacy_iterators: | ||||
|         iterator = _SingleWorkerOwnedDatasetIterator(worker_datasets[i], worker, | ||||
|                                                      worker_devices) | ||||
|       else: | ||||
|         iterator = _SingleWorkerDatasetIterator(worker_datasets[i], worker, | ||||
|                                                 worker_devices) | ||||
|       iterators.append(iterator) | ||||
|   return iterators | ||||
| 
 | ||||
| 
 | ||||
def _create_iterators_per_worker_with_input_context(input_contexts,
                                                    input_workers,
                                                    dataset_fn,
                                                    enable_legacy_iterators):
  """Create a multidevice iterator per workers given a dataset function.

  Args:
    input_contexts: Iterable of input contexts; `dataset_fn` is called once
      with each of them, and the i-th context is paired with the i-th worker
      of `input_workers`.
    input_workers: Object exposing `worker_devices` and
      `compute_devices_for_worker(i)`.
    dataset_fn: Callable that takes an input context and returns a dataset.
    enable_legacy_iterators: If true, `_SingleWorkerDatasetIterator` is used
      even when `tf2.enabled()`; otherwise `_SingleWorkerOwnedDatasetIterator`
      is used whenever `tf2.enabled()`.

  Returns:
    A tuple `(iterators, element_spec)` where `iterators` holds one iterator
    per input context and `element_spec` is the `element_spec` of the dataset
    built for the last input context.

  Raises:
    ValueError: If `input_contexts` is empty.
  """
  iterators = []
  # Only the spec of the most recently built dataset is returned, so it is
  # tracked directly instead of reading a loop variable after the loop.
  element_spec = None
  for i, ctx in enumerate(input_contexts):
    worker = input_workers.worker_devices[i]
    with ops.device(worker):
      dataset = dataset_fn(ctx)
      element_spec = dataset.element_spec
      devices = input_workers.compute_devices_for_worker(i)
      if tf2.enabled() and not enable_legacy_iterators:
        iterator = _SingleWorkerOwnedDatasetIterator(dataset, worker,
                                                     devices)
      else:
        iterator = _SingleWorkerDatasetIterator(dataset, worker,
                                                devices)
      iterators.append(iterator)
  if not iterators:
    # Without this check an empty input would fail with an opaque
    # UnboundLocalError on the return statement.
    raise ValueError("`input_contexts` must contain at least one input "
                     "context.")
  return iterators, element_spec
| 
 | ||||
|  | ||||
| @ -45,6 +45,7 @@ from tensorflow.python.eager import context | ||||
| from tensorflow.python.eager import def_function | ||||
| from tensorflow.python.eager import test | ||||
| from tensorflow.python.framework import errors | ||||
| from tensorflow.python.framework import ops | ||||
| from tensorflow.python.framework import sparse_tensor | ||||
| from tensorflow.python.ops import control_flow_ops | ||||
| from tensorflow.python.ops import math_ops | ||||
| @ -105,20 +106,21 @@ class DistributedIteratorTestBase(test.TestCase): | ||||
|                     split_batch_by, | ||||
|                     strategy, | ||||
|                     input_context=None): | ||||
|     if isinstance(dataset, (dataset_ops.Dataset, dataset_ops.DatasetV1Adapter)): | ||||
|       return input_lib.DistributedDatasetV1( | ||||
|           dataset, | ||||
|           input_workers, | ||||
|           strategy, | ||||
|           split_batch_by=split_batch_by, | ||||
|           input_context=input_context) | ||||
|     elif input_type == "dataset": | ||||
|       return input_lib.DistributedDataset( | ||||
|           dataset, | ||||
|           input_workers, | ||||
|           strategy, | ||||
|           split_batch_by=split_batch_by, | ||||
|           input_context=input_context) | ||||
|     if input_type == "dataset": | ||||
|       if tf2.enabled(): | ||||
|         return input_lib.DistributedDataset( | ||||
|             dataset, | ||||
|             input_workers, | ||||
|             strategy, | ||||
|             split_batch_by=split_batch_by, | ||||
|             input_context=input_context) | ||||
|       else: | ||||
|         return input_lib.DistributedDatasetV1( | ||||
|             dataset, | ||||
|             input_workers, | ||||
|             strategy, | ||||
|             split_batch_by=split_batch_by, | ||||
|             input_context=input_context) | ||||
|     else: | ||||
|       return strategy.experimental_distribute_datasets_from_function(dataset) | ||||
| 
 | ||||
| @ -139,6 +141,9 @@ class DistributedIteratorTestBase(test.TestCase): | ||||
|     if api_type == "wrap_into_iterator" and iteration_type == "for_loop": | ||||
|       self.skipTest("unsupported test combination.") | ||||
| 
 | ||||
|     if api_type == "wrap_into_iterator" and input_type == "input_fn": | ||||
|       self.skipTest("unsupported test combination.") | ||||
| 
 | ||||
|     devices = nest.flatten([ds for _, ds in worker_device_pairs]) | ||||
|     input_workers = input_lib.InputWorkers(worker_device_pairs) | ||||
| 
 | ||||
| @ -161,7 +166,7 @@ class DistributedIteratorTestBase(test.TestCase): | ||||
|           strategy, | ||||
|           input_context=input_context) | ||||
| 
 | ||||
|       if context.executing_eagerly(): | ||||
|       if ops.executing_eagerly_outside_functions(): | ||||
|         iterator = iter(dataset) | ||||
|       else: | ||||
|         if isinstance(dataset, input_lib.DistributedDatasetV1): | ||||
| @ -171,7 +176,7 @@ class DistributedIteratorTestBase(test.TestCase): | ||||
| 
 | ||||
|     if iteration_type == "get_next": | ||||
|       evaluate = lambda x: sess.run(x) if sess else self.evaluate(x) | ||||
|       if isinstance(iterator, input_lib.DistributedIteratorV1): | ||||
|       if not ops.executing_eagerly_outside_functions(): | ||||
|         evaluate(control_flow_ops.group(iterator.initializer)) | ||||
| 
 | ||||
|       for expected_value in expected_values: | ||||
| @ -190,10 +195,13 @@ class DistributedIteratorTestBase(test.TestCase): | ||||
|                                    next_element) for r in range(len(devices))]) | ||||
| 
 | ||||
|       # After re-initializing the iterator, should be able to iterate again. | ||||
|       if isinstance(iterator, input_lib.DistributedIteratorV1): | ||||
|       if not ops.executing_eagerly_outside_functions(): | ||||
|         evaluate(control_flow_ops.group(iterator.initializer)) | ||||
|       else: | ||||
|         evaluate(control_flow_ops.group(iterator._initializer)) | ||||
|         if api_type == "wrap_into_iterator": | ||||
|           self.skipTest("unsupported test combination") | ||||
|         else: | ||||
|           iterator = iter(dataset) | ||||
| 
 | ||||
|       for expected_value in expected_values: | ||||
|         next_element = iterator.get_next() | ||||
| @ -225,6 +233,48 @@ class DistributedIteratorTestBase(test.TestCase): | ||||
| class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, | ||||
|                                           parameterized.TestCase): | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.combine( | ||||
|           mode=["eager"], | ||||
|           input_type=["input_fn", "dataset"], | ||||
|           distribution=[ | ||||
|               strategy_combinations.one_device_strategy, | ||||
|               strategy_combinations.mirrored_strategy_with_one_cpu, | ||||
|               strategy_combinations.mirrored_strategy_with_gpu_and_cpu | ||||
|           ])) | ||||
|   def testDisablingOwnedIteratorsInTF2(self, distribution, input_type): | ||||
|     if not tf2.enabled(): | ||||
|       self.skipTest("unsupported test combination") | ||||
| 
 | ||||
|     worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] | ||||
|     input_workers = input_lib.InputWorkers(worker_device_pairs) | ||||
|     dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) | ||||
|     dataset_or_input_fn = self._create_dataset_or_input_fn( | ||||
|         input_type, dataset_fn) | ||||
| 
 | ||||
|     input_workers = input_lib.InputWorkers(worker_device_pairs) | ||||
|     if input_type == "dataset": | ||||
|       dist_dataset = input_lib.get_distributed_dataset(dataset_or_input_fn, | ||||
|                                                        input_workers, | ||||
|                                                        distribution) | ||||
|     else: | ||||
|       dist_dataset = input_lib.get_distributed_datasets_from_function( | ||||
|           dataset_or_input_fn, input_workers, [distribute_lib.InputContext()], | ||||
|           distribution) | ||||
| 
 | ||||
|     # Default Iterator types in TF2. | ||||
|     iterator = iter(dist_dataset) | ||||
|     self.assertIsInstance(iterator, input_lib.DistributedIterator) | ||||
|     self.assertIsInstance(iterator._iterators[0], | ||||
|                           input_lib._SingleWorkerOwnedDatasetIterator) | ||||
| 
 | ||||
|     # Disable creating owned iterators by setting a property on the strategy. | ||||
|     distribution._enable_legacy_iterators = True | ||||
|     iterator = iter(dist_dataset) | ||||
|     self.assertIsInstance(iterator, input_lib.DistributedIteratorV1) | ||||
|     self.assertIsInstance(iterator._iterators[0], | ||||
|                           input_lib._SingleWorkerDatasetIterator) | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.combine( | ||||
|           mode=["eager"], | ||||
| @ -234,7 +284,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, | ||||
|   def testMultiDeviceIterInitialize(self, distribution): | ||||
|     if tf2.enabled(): | ||||
|       self.skipTest("Only V1 is supported.") | ||||
|     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] | ||||
|     worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", | ||||
|                                               "/device:CPU:0"])] | ||||
|     dataset_fn = lambda _: dataset_ops.DatasetV1.range(10) | ||||
| 
 | ||||
|     input_workers = input_lib.InputWorkers(worker_device_pairs) | ||||
| @ -250,25 +301,6 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, | ||||
| 
 | ||||
|     init_func_for_iter() | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.combine( | ||||
|           mode=["graph"], | ||||
|           distribution=[ | ||||
|               strategy_combinations.one_device_strategy, | ||||
|               strategy_combinations.mirrored_strategy_with_one_cpu | ||||
|           ])) | ||||
|   def testDatasetV2IterError(self, distribution): | ||||
|     worker_device_pairs = [("", ["/device:CPU:0"])] | ||||
|     input_workers = input_lib.InputWorkers(worker_device_pairs) | ||||
|     dataset_fn = lambda _: dataset_ops.DatasetV2.range(10).batch(2) | ||||
| 
 | ||||
|     dist_dataset = input_lib.get_distributed_dataset( | ||||
|         dataset_fn(distribute_lib.InputContext()), input_workers, distribution) | ||||
| 
 | ||||
|     with self.assertRaisesRegexp(RuntimeError, | ||||
|                                  "or when eager execution is enabled"): | ||||
|       iter(dist_dataset) | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.combine( | ||||
|           mode=["graph", "eager"], | ||||
| @ -282,11 +314,11 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, | ||||
|           enable_get_next_as_optional=[True, False])) | ||||
|   def testOneDeviceCPU(self, input_type, api_type, iteration_type, distribution, | ||||
|                        enable_get_next_as_optional): | ||||
|     worker_device_pairs = [("", ["/device:CPU:0"])] | ||||
|     worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] | ||||
|     if tf2.enabled(): | ||||
|       dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) | ||||
|     else: | ||||
|       dataset_fn = lambda _: dataset_ops.Dataset.range(10) | ||||
|       dataset_fn = lambda _: dataset_ops.DatasetV1.range(10) | ||||
|     dataset_or_input_fn = self._create_dataset_or_input_fn( | ||||
|         input_type, dataset_fn) | ||||
| 
 | ||||
| @ -316,7 +348,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, | ||||
|           enable_get_next_as_optional=[True, False])) | ||||
|   def testTwoDevicesOneGPUOneCPU(self, input_type, api_type, iteration_type, | ||||
|                                  distribution, enable_get_next_as_optional): | ||||
|     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] | ||||
|     worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", | ||||
|                                               "/device:CPU:0"])] | ||||
|     if tf2.enabled(): | ||||
|       dataset_fn = lambda _: dataset_ops.DatasetV2.range(10) | ||||
|     else: | ||||
| @ -386,7 +419,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, | ||||
|           enable_get_next_as_optional=[True, False])) | ||||
|   def testTupleDataset(self, input_type, api_type, iteration_type, distribution, | ||||
|                        enable_get_next_as_optional): | ||||
|     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] | ||||
|     worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", | ||||
|                                               "/device:CPU:0"])] | ||||
| 
 | ||||
|     def dataset_fn(ctx): | ||||
|       del ctx | ||||
| @ -422,7 +456,7 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, | ||||
|               strategy_combinations.mirrored_strategy_with_one_cpu | ||||
|           ])) | ||||
|   def testIterableIterator(self, distribution): | ||||
|     worker_device_pairs = [("", ["/device:CPU:0"])] | ||||
|     worker_device_pairs = [("/device:CPU:0", ["/device:CPU:0"])] | ||||
|     input_workers = input_lib.InputWorkers(worker_device_pairs) | ||||
| 
 | ||||
|     dataset = dataset_ops.DatasetV2.range(10) | ||||
| @ -446,7 +480,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, | ||||
|           ])) | ||||
|   def testUnevenDatasetBatches(self, input_type, api_type, iteration_type, | ||||
|                                drop_remainder, distribution): | ||||
|     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] | ||||
|     worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", | ||||
|                                               "/device:CPU:0"])] | ||||
|     if tf2.enabled(): | ||||
|       dataset_fn = lambda _: dataset_ops.DatasetV2.range(9).batch(  # pylint: disable=g-long-lambda | ||||
|           2, drop_remainder=drop_remainder) | ||||
| @ -486,7 +521,8 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase, | ||||
|   def testBatchSplitting(self, input_type, api_type, iteration_type, | ||||
|                          split_batch_by, distribution, | ||||
|                          enable_get_next_as_optional): | ||||
|     worker_device_pairs = [("", ["/device:GPU:0", "/device:CPU:0"])] | ||||
|     worker_device_pairs = [("/device:CPU:0", ["/device:GPU:0", | ||||
|                                               "/device:CPU:0"])] | ||||
|     batch_size = 10 | ||||
|     if tf2.enabled(): | ||||
|       dataset_fn = lambda _: dataset_ops.DatasetV2.range(100).batch(batch_size) | ||||
| @ -1075,68 +1111,5 @@ class DistributedIteratorMultiWorkerTest( | ||||
|           strategy, | ||||
|           sess=sess) | ||||
| 
 | ||||
| 
 | ||||
| class InputTypeSpecTest(test.TestCase, parameterized.TestCase): | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.combine( | ||||
|           mode=["eager"], | ||||
|           distribution=[ | ||||
|               strategy_combinations.one_device_strategy, | ||||
|               strategy_combinations.mirrored_strategy_with_one_cpu, | ||||
|               strategy_combinations.mirrored_strategy_with_gpu_and_cpu, | ||||
|               strategy_combinations.tpu_strategy, | ||||
|               strategy_combinations.central_storage_strategy_with_two_gpus, | ||||
|           ], | ||||
|           input_type=["dataset", "dataset_fn"], | ||||
|       )) | ||||
|   def testInputSignatureForPerReplicaValues(self, distribution, input_type): | ||||
|     def dataset_fn(ctx): | ||||
|       del ctx  # unused | ||||
|       return dataset_ops.DatasetV2.from_tensor_slices( | ||||
|           np.ones([10, 12]).astype(np.float32)).batch(4) | ||||
| 
 | ||||
|     if input_type == "dataset": | ||||
|       ds = distribution.experimental_distribute_dataset( | ||||
|           dataset_fn(distribute_lib.InputContext())) | ||||
|       type_spec = ds.element_spec | ||||
|     else: | ||||
|       ds = distribution.experimental_distribute_datasets_from_function( | ||||
|           dataset_fn) | ||||
|       iterator = iter(ds) | ||||
|       type_spec = iterator.element_spec | ||||
| 
 | ||||
|     @def_function.function(input_signature=[type_spec]) | ||||
|     def process_inputs(inputs): | ||||
|       distribution.run(lambda inputs: inputs, args=(inputs,)) | ||||
| 
 | ||||
|     for x in ds: | ||||
|       process_inputs(x) | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.combine( | ||||
|           mode=["eager"], | ||||
|           distribution=[ | ||||
|               strategy_combinations.one_device_strategy, | ||||
|               strategy_combinations.mirrored_strategy_with_one_cpu, | ||||
|               strategy_combinations.mirrored_strategy_with_gpu_and_cpu, | ||||
|               strategy_combinations.tpu_strategy, | ||||
|               strategy_combinations.central_storage_strategy_with_two_gpus, | ||||
|           ], | ||||
|       )) | ||||
|   def testInputSignatureForNestedPerReplicaValues(self, distribution): | ||||
|     a = np.ones((10, 2)) * 5 | ||||
|     b = np.ones((10, 3)) * 6 | ||||
|     dataset = dataset_ops.DatasetV2.from_tensor_slices((a, b)).batch(2) | ||||
| 
 | ||||
|     dist_dataset = distribution.experimental_distribute_dataset(dataset) | ||||
| 
 | ||||
|     @def_function.function(input_signature=[dist_dataset.element_spec]) | ||||
|     def process_inputs(inputs): | ||||
|       distribution.run(lambda inputs: inputs, args=(inputs,)) | ||||
| 
 | ||||
|     for x in dist_dataset: | ||||
|       process_inputs(x) | ||||
| 
 | ||||
| if __name__ == "__main__": | ||||
|   test.main() | ||||
|  | ||||
							
								
								
									
										366
									
								
								tensorflow/python/distribute/input_lib_type_spec_test.py
									
									
									
									
									
										Normal file
									
								
							
							
						
						
									
										366
									
								
								tensorflow/python/distribute/input_lib_type_spec_test.py
									
									
									
									
									
										Normal file
									
								
							| @ -0,0 +1,366 @@ | ||||
| # Copyright 2018 The TensorFlow Authors. All Rights Reserved. | ||||
| # | ||||
| # Licensed under the Apache License, Version 2.0 (the "License"); | ||||
| # you may not use this file except in compliance with the License. | ||||
| # You may obtain a copy of the License at | ||||
| # | ||||
| #     http://www.apache.org/licenses/LICENSE-2.0 | ||||
| # | ||||
| # Unless required by applicable law or agreed to in writing, software | ||||
| # distributed under the License is distributed on an "AS IS" BASIS, | ||||
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||||
| # See the License for the specific language governing permissions and | ||||
| # limitations under the License. | ||||
| # ============================================================================== | ||||
| """Tests for the input_lib library which tests iterator type specs.""" | ||||
| 
 | ||||
| from __future__ import absolute_import | ||||
| from __future__ import division | ||||
| from __future__ import print_function | ||||
| 
 | ||||
| 
 | ||||
| from absl.testing import parameterized | ||||
| import numpy as np | ||||
| 
 | ||||
| from tensorflow.python import tf2 | ||||
| from tensorflow.python.data.ops import dataset_ops | ||||
| from tensorflow.python.distribute import combinations | ||||
| from tensorflow.python.distribute import distribute_lib | ||||
| from tensorflow.python.distribute import strategy_combinations | ||||
| from tensorflow.python.distribute import values | ||||
| from tensorflow.python.eager import def_function | ||||
| from tensorflow.python.eager import test | ||||
| from tensorflow.python.framework import dtypes | ||||
| from tensorflow.python.framework import sparse_tensor | ||||
| from tensorflow.python.framework import tensor_shape | ||||
| from tensorflow.python.framework import tensor_spec | ||||
| from tensorflow.python.ops.ragged import ragged_tensor as ragged_tensor_lib | ||||
| 
 | ||||
| 
 | ||||
class DistributedIteratorTest(test.TestCase,
                              parameterized.TestCase):
  # Tests covering the type spec of iterators created from distributed
  # datasets: spec contents, component round trip and function retracing.

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
          ],
          enable_get_next_as_optional=[True, False]))
  def testTypeSpec(self, input_type, distribution,
                   enable_get_next_as_optional):
    # The iterator's spec must expose the iterator's input workers and one
    # int64 `TensorSpec` with unknown batch size per replica value.
    if not tf2.enabled():
      self.skipTest("DistributedIterator has CompositeTensor support in "
                    "TF 2 only.")

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    source = dataset_ops.DatasetV2.range(10).batch(2)
    distributed = distribution.experimental_distribute_dataset(source)
    with distribution.scope():
      iterator = iter(distributed)

    iterator_spec = iterator._type_spec
    self.assertEqual(iterator_spec._input_workers, iterator._input_workers)

    expected_value_specs = (
        tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64, name=None),
        tensor_spec.TensorSpec(shape=(None,), dtype=dtypes.int64, name=None),
    )
    self.assertEqual(iterator_spec._element_spec._value_specs,
                     expected_value_specs)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
          ],
          enable_get_next_as_optional=[True, False]))
  def testTypeSpecRoundTrip(self, input_type,
                            distribution, enable_get_next_as_optional):
    # Decomposing an iterator into components and rebuilding it from them
    # must preserve its input workers and per-worker iterators.
    if not tf2.enabled():
      self.skipTest("DistributedIterator CompositeTensor support is only "
                    "present in TF 2.0 only.")

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    source = dataset_ops.DatasetV2.range(10).batch(2)
    distributed = distribution.experimental_distribute_dataset(source)
    with distribution.scope():
      iterator = iter(distributed)

    iterator_spec = iterator._type_spec
    components = iterator_spec._to_components(iterator)
    rebuilt = iterator_spec._from_components(components)

    self.assertEqual(iterator._input_workers, rebuilt._input_workers)
    self.assertAllEqual(iterator._iterators, rebuilt._iterators)

  @combinations.generate(
      combinations.combine(
          mode=["eager"],
          input_type=["dataset"],
          distribution=[
              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
              strategy_combinations.tpu_strategy,
          ],
          enable_get_next_as_optional=[True, False]))
  def testDoesNotTriggerFunctionTracing(self, input_type, distribution,
                                        enable_get_next_as_optional):
    # Passing fresh iterators over the same distributed dataset to one
    # `def_function.function` must trace that function only once.
    if not tf2.enabled():
      self.skipTest("DistributedIterator CompositeTensor support is only "
                    "present in TF 2.0 only.")

    # Mutable cell so the traced function body can record each trace.
    trace_count = [0]

    @def_function.function
    def f(iterator):
      trace_count[0] += 1
      counter = np.int64(0)
      for _ in range(5):
        next(iterator)
        counter += 1
      return counter

    distribution.extended.experimental_enable_get_next_as_optional = (
        enable_get_next_as_optional)

    source = dataset_ops.DatasetV2.range(10).batch(2)
    distributed = distribution.experimental_distribute_dataset(source)
    with distribution.scope():
      for _ in range(3):
        fresh_iterator = iter(distributed)
        steps_taken = f(fresh_iterator)

        self.assertEqual(trace_count[0], 1)
        self.assertEqual(steps_taken, 5)
| 
 | ||||
| 
 | ||||
| class InputTypeSpecTest(test.TestCase, parameterized.TestCase): | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.combine( | ||||
|           mode=["eager"], | ||||
|           distribution=[ | ||||
|               strategy_combinations.one_device_strategy, | ||||
|               strategy_combinations.mirrored_strategy_with_one_cpu, | ||||
|               strategy_combinations.mirrored_strategy_with_gpu_and_cpu, | ||||
|               strategy_combinations.tpu_strategy, | ||||
|               strategy_combinations.central_storage_strategy_with_two_gpus, | ||||
|           ], | ||||
|           input_type=["dataset", "dataset_fn"], | ||||
|       )) | ||||
|   def testInputSignatureForPerReplicaValues(self, distribution, input_type): | ||||
|     def dataset_fn(ctx): | ||||
|       del ctx  # unused | ||||
|       return dataset_ops.DatasetV2.from_tensor_slices( | ||||
|           np.ones([10, 12]).astype(np.float32)).batch(4) | ||||
| 
 | ||||
|     if input_type == "dataset": | ||||
|       ds = distribution.experimental_distribute_dataset( | ||||
|           dataset_fn(distribute_lib.InputContext())) | ||||
|       type_spec = ds.element_spec | ||||
|     else: | ||||
|       ds = distribution.experimental_distribute_datasets_from_function( | ||||
|           dataset_fn) | ||||
|       iterator = iter(ds) | ||||
|       type_spec = iterator.element_spec | ||||
| 
 | ||||
|     @def_function.function(input_signature=[type_spec]) | ||||
|     def process_inputs(inputs): | ||||
|       distribution.run(lambda inputs: inputs, args=(inputs,)) | ||||
| 
 | ||||
|     for x in ds: | ||||
|       process_inputs(x) | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.combine( | ||||
|           mode=["eager"], | ||||
|           distribution=[ | ||||
|               strategy_combinations.one_device_strategy, | ||||
|               strategy_combinations.mirrored_strategy_with_one_cpu, | ||||
|               strategy_combinations.mirrored_strategy_with_gpu_and_cpu, | ||||
|               strategy_combinations.tpu_strategy, | ||||
|               strategy_combinations.central_storage_strategy_with_two_gpus, | ||||
|           ], | ||||
|       )) | ||||
|   def testInputSignatureForNestedPerReplicaValues(self, distribution): | ||||
|     a = np.ones((10, 2)) * 5 | ||||
|     b = np.ones((10, 3)) * 6 | ||||
|     dataset = dataset_ops.DatasetV2.from_tensor_slices((a, b)).batch(2) | ||||
| 
 | ||||
|     dist_dataset = distribution.experimental_distribute_dataset(dataset) | ||||
| 
 | ||||
|     @def_function.function(input_signature=[dist_dataset.element_spec]) | ||||
|     def process_inputs(inputs): | ||||
|       distribution.run(lambda inputs: inputs, args=(inputs,)) | ||||
| 
 | ||||
|     for x in dist_dataset: | ||||
|       process_inputs(x) | ||||
| 
 | ||||
| 
 | ||||
| class RaggedTensorDistributedIteratorTest(test.TestCase, | ||||
|                                           parameterized.TestCase): | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.combine( | ||||
|           mode=["eager"], | ||||
|           distribution=[ | ||||
|               strategy_combinations.mirrored_strategy_with_gpu_and_cpu, | ||||
|           ], | ||||
|           enable_get_next_as_optional=[True, False])) | ||||
|   def testTypeSpec(self, distribution, enable_get_next_as_optional): | ||||
|     if not tf2.enabled(): | ||||
|       self.skipTest("DistributedIterator has CompositeTensor support in " | ||||
|                     "TF 2.0 only.") | ||||
|     ctx = distribute_lib.InputContext() | ||||
|     batch_size = ctx.get_per_replica_batch_size(8) | ||||
|     # Use 20 which isn't divisible by 8 to test partial batch behavior. | ||||
|     row_lengths = np.mod(np.arange(20), 4).astype(np.int64) | ||||
|     ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths( | ||||
|         np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths) | ||||
|     dataset = dataset_ops.DatasetV2.from_tensor_slices({ | ||||
|         "dense": ragged_tensor.to_tensor(), | ||||
|         "ragged": ragged_tensor, | ||||
|         "sparse": ragged_tensor.to_sparse(), | ||||
|     }) | ||||
|     dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) | ||||
|     dataset = dataset.batch(batch_size) | ||||
| 
 | ||||
|     distribution.extended.experimental_enable_get_next_as_optional = ( | ||||
|         enable_get_next_as_optional) | ||||
| 
 | ||||
|     dist_dataset = distribution.experimental_distribute_dataset(dataset) | ||||
|     with distribution.scope(): | ||||
|       iterator = iter(dist_dataset) | ||||
| 
 | ||||
|     spec = iterator._type_spec | ||||
|     self.assertEqual(spec._input_workers, iterator._input_workers) | ||||
|     self.assertEqual( | ||||
|         spec._element_spec, { | ||||
|             "sparse": | ||||
|                 values.PerReplicaSpec( | ||||
|                     sparse_tensor.SparseTensorSpec( | ||||
|                         tensor_shape.TensorShape([None, 3]), dtypes.float32), | ||||
|                     sparse_tensor.SparseTensorSpec( | ||||
|                         tensor_shape.TensorShape([None, 3]), dtypes.float32)), | ||||
|             "dense": | ||||
|                 values.PerReplicaSpec( | ||||
|                     tensor_spec.TensorSpec( | ||||
|                         shape=(None, 3), dtype=dtypes.float32, name=None), | ||||
|                     tensor_spec.TensorSpec( | ||||
|                         shape=(None, 3), dtype=dtypes.float32, name=None)), | ||||
|             "ragged": | ||||
|                 values.PerReplicaSpec( | ||||
|                     ragged_tensor_lib.RaggedTensorSpec( | ||||
|                         tensor_shape.TensorShape([None, None]), dtypes.float32, | ||||
|                         1, dtypes.int64), | ||||
|                     ragged_tensor_lib.RaggedTensorSpec( | ||||
|                         tensor_shape.TensorShape([None, None]), dtypes.float32, | ||||
|                         1, dtypes.int64)) | ||||
|         }) | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.combine( | ||||
|           mode=["eager"], | ||||
|           distribution=[ | ||||
|               strategy_combinations.mirrored_strategy_with_gpu_and_cpu, | ||||
|               strategy_combinations.tpu_strategy, | ||||
|           ], | ||||
|           enable_get_next_as_optional=[True, False])) | ||||
|   def testTypeSpecRoundTrip(self, distribution, enable_get_next_as_optional): | ||||
|     if not tf2.enabled(): | ||||
|       self.skipTest("DistributedIterator CompositeTensor support is only " | ||||
|                     "present in TF 2.0 only.") | ||||
| 
 | ||||
|     ctx = distribute_lib.InputContext() | ||||
|     batch_size = ctx.get_per_replica_batch_size(8) | ||||
|     # Use 20 which isn't divisible by 8 to test partial batch behavior. | ||||
|     row_lengths = np.mod(np.arange(20), 4).astype(np.int64) | ||||
|     ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths( | ||||
|         np.repeat(np.arange(20, dtype=np.float32), row_lengths), row_lengths) | ||||
|     dataset = dataset_ops.DatasetV2.from_tensor_slices({ | ||||
|         "dense": ragged_tensor.to_tensor(), | ||||
|         "ragged": ragged_tensor, | ||||
|         "sparse": ragged_tensor.to_sparse(), | ||||
|     }) | ||||
|     dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) | ||||
|     dataset = dataset.batch(batch_size) | ||||
| 
 | ||||
|     distribution.extended.experimental_enable_get_next_as_optional = ( | ||||
|         enable_get_next_as_optional) | ||||
| 
 | ||||
|     dist_dataset = distribution.experimental_distribute_dataset(dataset) | ||||
|     with distribution.scope(): | ||||
|       iterator = iter(dist_dataset) | ||||
| 
 | ||||
|     spec = iterator._type_spec | ||||
| 
 | ||||
|     tensor_list = spec._to_components(iterator) | ||||
|     re_iterator = spec._from_components(tensor_list) | ||||
| 
 | ||||
|     self.assertEqual(iterator._input_workers, re_iterator._input_workers) | ||||
|     self.assertAllEqual(iterator._iterators, re_iterator._iterators) | ||||
| 
 | ||||
|   @combinations.generate( | ||||
|       combinations.combine( | ||||
|           mode=["eager"], | ||||
|           distribution=[ | ||||
|               strategy_combinations.mirrored_strategy_with_gpu_and_cpu, | ||||
|               strategy_combinations.tpu_strategy, | ||||
|           ], | ||||
|           enable_get_next_as_optional=[True, False])) | ||||
|   def testDoesNotTriggerFunctionTracing(self, distribution, | ||||
|                                         enable_get_next_as_optional): | ||||
|     if not tf2.enabled(): | ||||
|       self.skipTest("DistributedIterator CompositeTensor support is only " | ||||
|                     "present in TF 2.0 only.") | ||||
| 
 | ||||
|     trace_count = [0] | ||||
| 
 | ||||
|     @def_function.function | ||||
|     def f(iterator): | ||||
|       trace_count[0] += 1 | ||||
|       counter = np.int64(0) | ||||
|       for _ in range(5): | ||||
|         next(iterator) | ||||
|         counter += 1 | ||||
|       return counter | ||||
| 
 | ||||
|     ctx = distribute_lib.InputContext() | ||||
|     batch_size = ctx.get_per_replica_batch_size(8) | ||||
|     # Use 20 which isn't divisible by 8 to test partial batch behavior. | ||||
|     row_lengths = np.mod(np.arange(50), 4).astype(np.int64) | ||||
|     ragged_tensor = ragged_tensor_lib.RaggedTensor.from_row_lengths( | ||||
|         np.repeat(np.arange(50, dtype=np.float32), row_lengths), row_lengths) | ||||
|     dataset = dataset_ops.DatasetV2.from_tensor_slices({ | ||||
|         "dense": ragged_tensor.to_tensor(), | ||||
|         "ragged": ragged_tensor, | ||||
|         "sparse": ragged_tensor.to_sparse(), | ||||
|     }) | ||||
|     dataset = dataset.shard(ctx.num_input_pipelines, ctx.input_pipeline_id) | ||||
|     dataset = dataset.batch(batch_size) | ||||
| 
 | ||||
|     distribution.extended.experimental_enable_get_next_as_optional = ( | ||||
|         enable_get_next_as_optional) | ||||
| 
 | ||||
|     dist_dataset = distribution.experimental_distribute_dataset(dataset) | ||||
|     with distribution.scope(): | ||||
|       for _ in range(3): | ||||
|         iterator = iter(dist_dataset) | ||||
|         counter = f(iterator) | ||||
| 
 | ||||
|         self.assertEqual(trace_count[0], 1) | ||||
|         self.assertEqual(counter, 5) | ||||
| 
 | ||||
| 
 | ||||
| # Call `test.main()` only when this file is executed directly, not imported. | ||||
| if __name__ == "__main__": | ||||
|   test.main() | ||||
| @ -1161,12 +1161,7 @@ class DataHandler(object): | ||||
|         if self._insufficient_data:  # Set by `catch_stop_iteration`. | ||||
|           break | ||||
|         if self._adapter.should_recreate_iterator(): | ||||
|           if ds_context.has_strategy(): | ||||
|             # TODO(b/138326910): remove this when MultiDeviceIterator is a | ||||
|             # CompositeTensor (unless this is more efficient) | ||||
|             data_iterator._initializer  # pylint: disable=pointless-statement, protected-access | ||||
|           else: | ||||
|             data_iterator = iter(self._dataset) | ||||
|           data_iterator = iter(self._dataset) | ||||
|         yield epoch, data_iterator | ||||
|         self._adapter.on_epoch_end() | ||||
| 
 | ||||
|  | ||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user