Construct the correct shape for zero-batch size in RebatchDataset

PiperOrigin-RevId: 337441780
Change-Id: Ic70551b00882711120dc0f21f9d23fd1e9c9155c
This commit is contained in:
Frank Chen 2020-10-15 21:39:12 -07:00 committed by TensorFlower Gardener
parent f0844f4065
commit 3ba72a8d69
2 changed files with 56 additions and 3 deletions

View File

@ -417,7 +417,6 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel {
std::vector<Tensor> slices;
slices.reserve(tensors_.size());
for (const auto& tensor : tensors_) {
Tensor slice = tensor.Slice(offset_, slice_end);
slices.push_back(tensor.Slice(offset_, slice_end));
}
slices_to_concatenate.push_back(std::move(slices));
@ -452,8 +451,28 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel {
if (desired_batch_size == 0) {
DCHECK_EQ(batch_size, 0);
DCHECK_EQ(slices_to_concatenate.size(), 0);
for (const auto& dtype : dataset()->output_dtypes()) {
out_tensors->push_back(Tensor(dtype));
for (int i = 0; i < dataset()->output_dtypes().size(); ++i) {
if (dataset()->output_shapes()[i].unknown_rank()) {
// For unknown rank tensors, we just create an empty Tensor since
// it doesn't matter what shape it is.
out_tensors->push_back(Tensor(dataset()->output_dtypes()[i]));
} else {
auto dim_sizes = dataset()->output_shapes()[i].dim_sizes();
// The output batch size is always zero since the desired batch
// size is zero.
dim_sizes[0] = 0;
// Handle unknown dimensions by setting any unknown dimensions to
// zero since there isn't any data anyway.
for (int j = 1; j < dim_sizes.size(); ++j) {
if (dim_sizes[j] == -1) dim_sizes[j] = 0;
}
TensorShape tensor_shape(dim_sizes);
out_tensors->push_back(
Tensor(dataset()->output_dtypes()[i], tensor_shape));
}
}
return Status::OK();
}

View File

@ -287,6 +287,40 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
expected_output = [[0], [1], [2], [3], [], [4], [5], [6], [7], []]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(drop_remainder=[True, False])))
def testEmptyFirstSplits(self, drop_remainder):
"""Rebatching with batch_sizes=[0, 1] emits an empty batch before each element.

With a desired batch size of 0 in the first split position, the rebatched
dataset interleaves empty batches ahead of every size-1 batch, and the
output shape stays [None] regardless of drop_remainder.
"""
dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
rebatched_dataset = distribute._RebatchDataset(
dataset, batch_sizes=[0, 1], drop_remainder=drop_remainder)
expected_shapes = [[None]]
self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
# We have an extra element at the end because if the desired batch size is
# zero, then we never read any inputs from the input_dataset at all, so we
# will keep producing empty outputs until we reach a non-zero desired batch
# size split.
expected_output = [[], [0], [], [1], [], [2], [], [3],
[], [4], [], [5], [], [6], [], [7], []]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(drop_remainder=[True, False])))
def testEmptyLastSplits(self, drop_remainder):
"""Rebatching with batch_sizes=[1, 0] emits an empty batch after each element.

Mirror of testEmptyFirstSplits: with the zero desired batch size in the
last split position, each size-1 batch is followed by an empty batch, and
the output shape stays [None] regardless of drop_remainder.
"""
dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
rebatched_dataset = distribute._RebatchDataset(
dataset, batch_sizes=[1, 0], drop_remainder=drop_remainder)
expected_shapes = [[None]]
self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
expected_output = [[0], [], [1], [], [2], [], [3], [],
[4], [], [5], [], [6], [], [7], []]
self.assertDatasetProduces(rebatched_dataset, expected_output)
@combinations.generate(
combinations.times(test_base.default_test_combinations(),
combinations.combine(drop_remainder=[True, False])))