From 3ba72a8d69a27228e6243026e2e0768027f05561 Mon Sep 17 00:00:00 2001
From: Frank Chen
Date: Thu, 15 Oct 2020 21:39:12 -0700
Subject: [PATCH] Construct the correct shape for zero-batch size in RebatchDataset

PiperOrigin-RevId: 337441780
Change-Id: Ic70551b00882711120dc0f21f9d23fd1e9c9155c
---
 .../data/experimental/rebatch_dataset_op.cc   | 25 ++++++++++++--
 .../kernel_tests/rebatch_dataset_test.py      | 34 +++++++++++++++++++
 2 files changed, 56 insertions(+), 3 deletions(-)

diff --git a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
index e2cbe7d9dcc..7a65baaa680 100644
--- a/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
+++ b/tensorflow/core/kernels/data/experimental/rebatch_dataset_op.cc
@@ -417,7 +417,6 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel {
           std::vector<Tensor> slices;
           slices.reserve(tensors_.size());
           for (const auto& tensor : tensors_) {
-            Tensor slice = tensor.Slice(offset_, slice_end);
             slices.push_back(tensor.Slice(offset_, slice_end));
           }
           slices_to_concatenate.push_back(std::move(slices));
@@ -452,8 +451,28 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel {
         if (desired_batch_size == 0) {
           DCHECK_EQ(batch_size, 0);
           DCHECK_EQ(slices_to_concatenate.size(), 0);
-          for (const auto& dtype : dataset()->output_dtypes()) {
-            out_tensors->push_back(Tensor(dtype));
+          for (int i = 0; i < dataset()->output_dtypes().size(); ++i) {
+            if (dataset()->output_shapes()[i].unknown_rank()) {
+              // For unknown-rank tensors, we just create an empty Tensor since
+              // it doesn't matter what shape it is.
+              out_tensors->push_back(Tensor(dataset()->output_dtypes()[i]));
+            } else {
+              auto dim_sizes = dataset()->output_shapes()[i].dim_sizes();
+
+              // The output batch size is always zero since the desired batch
+              // size is zero.
+              dim_sizes[0] = 0;
+
+              // Handle unknown dimensions by setting any unknown dimensions to
+              // zero since there isn't any data anyway.
+              for (int j = 1; j < dim_sizes.size(); ++j) {
+                if (dim_sizes[j] == -1) dim_sizes[j] = 0;
+              }
+
+              TensorShape tensor_shape(dim_sizes);
+              out_tensors->push_back(
+                  Tensor(dataset()->output_dtypes()[i], tensor_shape));
+            }
           }
           return Status::OK();
         }
diff --git a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
index 8175480182f..941ce327555 100644
--- a/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
+++ b/tensorflow/python/data/experimental/kernel_tests/rebatch_dataset_test.py
@@ -287,6 +287,40 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     expected_output = [[0], [1], [2], [3], [], [4], [5], [6], [7], []]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(drop_remainder=[True, False])))
+  def testEmptyFirstSplits(self, drop_remainder):
+    dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
+    rebatched_dataset = distribute._RebatchDataset(
+        dataset, batch_sizes=[0, 1], drop_remainder=drop_remainder)
+
+    expected_shapes = [[None]]
+    self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
+
+    # We have an extra element at the end because if the desired batch size is
+    # zero, then we never read any inputs from the input_dataset at all, so we
+    # will keep producing empty outputs until we reach a non-zero desired
+    # batch size split.
+    expected_output = [[], [0], [], [1], [], [2], [], [3],
+                       [], [4], [], [5], [], [6], [], [7], []]
+    self.assertDatasetProduces(rebatched_dataset, expected_output)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(drop_remainder=[True, False])))
+  def testEmptyLastSplits(self, drop_remainder):
+    dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
+    rebatched_dataset = distribute._RebatchDataset(
+        dataset, batch_sizes=[1, 0], drop_remainder=drop_remainder)
+
+    expected_shapes = [[None]]
+    self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
+
+    expected_output = [[0], [], [1], [], [2], [], [3], [],
+                       [4], [], [5], [], [6], [], [7], []]
+    self.assertDatasetProduces(rebatched_dataset, expected_output)
+
   @combinations.generate(
       combinations.times(test_base.default_test_combinations(),
                          combinations.combine(drop_remainder=[True, False])))
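For reference, a minimal usage sketch of the behavior the new tests exercise, assuming the same internal import paths used in rebatch_dataset_test.py and TF 2.x eager execution; this is illustrative only.

# Rebatching with a zero entry in batch_sizes produces empty batches; after
# this change they are built with batch dimension 0 and any unknown dimensions
# set to 0 (see the new else branch in rebatch_dataset_op.cc) instead of a
# default-constructed empty Tensor.
from tensorflow.python.data.experimental.ops import distribute
from tensorflow.python.data.ops import dataset_ops

dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
rebatched = distribute._RebatchDataset(dataset, batch_sizes=[0, 1])
for batch in rebatched:
  # Alternates empty and single-element batches:
  # (0,) [], (1,) [0], (0,) [], (1,) [1], ...
  print(batch.shape, batch.numpy())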