diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py
index 26bc9a087fb..a1e76ea0ecb 100644
--- a/tensorflow/python/distribute/input_lib.py
+++ b/tensorflow/python/distribute/input_lib.py
@@ -349,33 +349,6 @@ class DistributedIteratorBase(distribute_types.Iterator):
       results.append(result)
     replicas = results
 
-    # Some dimensions in `replicas` will become unknown after we conditionally
-    # return the real tensors or the dummy tensors. We fix the input shapes by
-    # using the shapes from `out_of_range_replicas` because it is calling
-    # get_next() inside.
-    flattened_replicas = nest.flatten(replicas)
-    for i, replica_data in enumerate(nest.flatten(out_of_range_replicas)):
-      for target, source in zip(
-          nest.flatten(flattened_replicas[i], expand_composites=True),
-          nest.flatten(replica_data, expand_composites=True)):
-        target.set_shape(source.get_shape())
-      # `SparseTensor` shape is not determined by the shape of its component
-      # tensors. Rather, its shape depends on a tensor's values.
-      if sparse_tensor.is_sparse(replica_data) and replica_data.get_shape():
-        dense_shape = replica_data.get_shape()
-        with ops.device(flattened_replicas[i].op.device):
-          # For partially defined shapes, fill in missing values from tensor.
-          if not dense_shape.is_fully_defined():
-            dense_shape = array_ops.stack([
-                flattened_replicas[i].dense_shape[j] if dim is None else dim
-                for j, dim in enumerate(dense_shape.as_list())
-            ])
-          flattened_replicas[i] = sparse_tensor.SparseTensor(
-              indices=flattened_replicas[i].indices,
-              values=flattened_replicas[i].values,
-              dense_shape=dense_shape)
-    replicas = nest.pack_sequence_as(replicas, flattened_replicas)
-
     return values.regroup(replicas)
@@ -1048,6 +1021,34 @@ def _dummy_tensor_fn(value_structure):
   return nest.map_structure(create_dummy_tensor, value_structure)
 
 
+def _recover_shape_fn(data, value_structure):
+  """Recovers the shape of `data` to match the shape of `value_structure`."""
+
+  flattened_data = nest.flatten(data)
+  for i, spec in enumerate(nest.flatten(value_structure)):
+    for target, source in zip(
+        nest.flatten(flattened_data[i], expand_composites=True),
+        nest.flatten(spec, expand_composites=True)):
+      target.set_shape(source.shape)
+    # `SparseTensor` shape is not determined by the shape of its component
+    # tensors. Rather, its shape depends on a tensor's values.
+    if isinstance(spec, sparse_tensor.SparseTensorSpec) and spec.shape:
+      dense_shape = spec.shape
+      with ops.device(flattened_data[i].op.device):
+        # For partially defined shapes, fill in missing values from tensor.
+        if not dense_shape.is_fully_defined():
+          dense_shape = array_ops.stack([
+              flattened_data[i].dense_shape[j] if dim is None else dim
+              for j, dim in enumerate(dense_shape.as_list())
+          ])
+        flattened_data[i] = sparse_tensor.SparseTensor(
+            indices=flattened_data[i].indices,
+            values=flattened_data[i].values,
+            dense_shape=dense_shape)
+  data = nest.pack_sequence_as(data, flattened_data)
+  return data
+
+
 class _SingleWorkerDatasetIteratorBase(object):
   """Iterator for a single `tf.data.Dataset`."""
 
@@ -1132,6 +1133,13 @@ class _SingleWorkerDatasetIteratorBase(object):
             lambda: _dummy_tensor_fn(data.value_structure),
             strict=True,
         )
+        # Some dimensions in `real_data` will become unknown after we
+        # conditionally return the real tensors or the dummy tensors. Recover
+        # the shapes from `data.value_structure`. We only need to do this in
+        # non-eager mode because we always know the runtime shape of the
+        # tensors in eager mode.
+        if not context.executing_eagerly():
+          real_data = _recover_shape_fn(real_data, data.value_structure)
         result.append(real_data)
         # pylint: enable=cell-var-from-loop
         # pylint: enable=unnecessary-lambda
diff --git a/tensorflow/python/distribute/input_lib_test.py b/tensorflow/python/distribute/input_lib_test.py
index 60212f7a3b7..2114c4e6bda 100644
--- a/tensorflow/python/distribute/input_lib_test.py
+++ b/tensorflow/python/distribute/input_lib_test.py
@@ -48,6 +48,7 @@ from tensorflow.python.framework import composite_tensor
 from tensorflow.python.framework import errors
 from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import control_flow_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import sparse_ops
@@ -612,6 +613,41 @@ class DistributedIteratorSingleWorkerTest(DistributedIteratorTestBase,
     else:
       self.assertAllEqual(first_epoch, second_epoch)
 
+  @combinations.generate(
+      combinations.combine(
+          mode=["eager"],
+          distribution=[
+              strategy_combinations.one_device_strategy,
+              strategy_combinations.mirrored_strategy_with_one_cpu,
+              strategy_combinations.mirrored_strategy_with_gpu_and_cpu,
+              strategy_combinations.tpu_strategy,
+              strategy_combinations.central_storage_strategy_with_two_gpus,
+          ]))
+  def testGetNextOptionalShape(self, distribution):
+    batch_size = 8
+    dataset = dataset_ops.DatasetV2.from_tensor_slices({
+        "feature": array_ops.ones([batch_size, 10]),
+        "label": array_ops.ones([batch_size]),
+    })
+    dataset = dataset.batch(batch_size, drop_remainder=True)
+    dist_dataset = distribution.experimental_distribute_dataset(dataset)
+    per_replica_batch_size = batch_size // distribution.num_replicas_in_sync
+
+    @def_function.function
+    def train_fn():
+      for data in dist_dataset:
+        data = nest.map_structure(distribution.experimental_local_results, data)
+        feature = data["feature"]
+        label = data["label"]
+
+        # Assert the shapes are still static on all replicas.
+        for replica_id in range(distribution.num_replicas_in_sync):
+          self.assertEqual([per_replica_batch_size, 10],
+                           feature[replica_id].shape)
+          self.assertEqual([per_replica_batch_size], label[replica_id].shape)
+
+    train_fn()
+
 
 class DistributedIteratorTensorTypeTest(DistributedIteratorTestBase,
                                         parameterized.TestCase):
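
Reviewer note: the snippet below is a minimal standalone sketch, not part of the patch, that reproduces the shape loss `_recover_shape_fn` repairs. It assumes only public TF 2.x APIs: `tf.data.experimental.get_next_as_optional` and `tf.cond` stand in for the internal iterator plumbing in `input_lib.py`, and `one_step` is a hypothetical name.

import tensorflow as tf

dataset = tf.data.Dataset.from_tensor_slices(tf.ones([8, 10]))
dataset = dataset.batch(4, drop_remainder=True)
spec = dataset.element_spec  # TensorSpec(shape=(4, 10), dtype=tf.float32)


@tf.function
def one_step(iterator):
  opt = tf.data.experimental.get_next_as_optional(iterator)
  # Conditionally returning the real batch or a dummy tensor, as
  # get_next_as_list() does, merges the branch shapes: the static shape
  # degrades from (4, 10) to (None, 10) in graph mode.
  value = tf.cond(
      opt.has_value(),
      lambda: opt.get_value(),
      lambda: tf.zeros([0, 10]))
  # What _recover_shape_fn does for dense tensors: re-attach the static
  # shape recorded in the element spec.
  value.set_shape(spec.shape)
  return value

print(one_step(iter(dataset)).shape)  # (4, 10)

In eager mode the runtime shape is always concrete, which is why the patch gates the recovery on `context.executing_eagerly()`. For `SparseTensor`s the static shape alone is not enough: the dense shape lives in a tensor value, so the sparse branch of `_recover_shape_fn` rebuilds it, filling unknown dimensions from the tensor's own `dense_shape`.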