Construct the correct shape for zero-batch size in RebatchDataset
PiperOrigin-RevId: 337441780 Change-Id: Ic70551b00882711120dc0f21f9d23fd1e9c9155c
This commit is contained in:
parent
f0844f4065
commit
3ba72a8d69
@ -417,7 +417,6 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel {
|
||||
std::vector<Tensor> slices;
|
||||
slices.reserve(tensors_.size());
|
||||
for (const auto& tensor : tensors_) {
|
||||
Tensor slice = tensor.Slice(offset_, slice_end);
|
||||
slices.push_back(tensor.Slice(offset_, slice_end));
|
||||
}
|
||||
slices_to_concatenate.push_back(std::move(slices));
|
||||
@ -452,8 +451,28 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel {
|
||||
if (desired_batch_size == 0) {
|
||||
DCHECK_EQ(batch_size, 0);
|
||||
DCHECK_EQ(slices_to_concatenate.size(), 0);
|
||||
for (const auto& dtype : dataset()->output_dtypes()) {
|
||||
out_tensors->push_back(Tensor(dtype));
|
||||
for (int i = 0; i < dataset()->output_dtypes().size(); ++i) {
|
||||
if (dataset()->output_shapes()[i].unknown_rank()) {
|
||||
// For unknown rank tensors, we just create an empty Tensor since
|
||||
// it doesn't matter what shape it is.
|
||||
out_tensors->push_back(Tensor(dataset()->output_dtypes()[i]));
|
||||
} else {
|
||||
auto dim_sizes = dataset()->output_shapes()[i].dim_sizes();
|
||||
|
||||
// The output batch size is always zero since the desired batch
|
||||
// size is zero.
|
||||
dim_sizes[0] = 0;
|
||||
|
||||
// Handle unknown dimensions by setting any unknown dimensions to
|
||||
// zero since there isn't any data anyway.
|
||||
for (int j = 1; j < dim_sizes.size(); ++j) {
|
||||
if (dim_sizes[j] == -1) dim_sizes[j] = 0;
|
||||
}
|
||||
|
||||
TensorShape tensor_shape(dim_sizes);
|
||||
out_tensors->push_back(
|
||||
Tensor(dataset()->output_dtypes()[i], tensor_shape));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -287,6 +287,40 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
expected_output = [[0], [1], [2], [3], [], [4], [5], [6], [7], []]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(drop_remainder=[True, False])))
|
||||
def testEmptyFirstSplits(self, drop_remainder):
|
||||
dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
|
||||
rebatched_dataset = distribute._RebatchDataset(
|
||||
dataset, batch_sizes=[0, 1], drop_remainder=drop_remainder)
|
||||
|
||||
expected_shapes = [[None]]
|
||||
self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
|
||||
|
||||
# We have an extra element at the end because if the desired batch size is
|
||||
# zero, then we never read any inputs from the input_dataset at all, so we
|
||||
# will keep producing empty outputs until we reach a non-zero desired batch
|
||||
# size split.
|
||||
expected_output = [[], [0], [], [1], [], [2], [], [3],
|
||||
[], [4], [], [5], [], [6], [], [7], []]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(drop_remainder=[True, False])))
|
||||
def testEmptyLastSplits(self, drop_remainder):
|
||||
dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
|
||||
rebatched_dataset = distribute._RebatchDataset(
|
||||
dataset, batch_sizes=[1, 0], drop_remainder=drop_remainder)
|
||||
|
||||
expected_shapes = [[None]]
|
||||
self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
|
||||
|
||||
expected_output = [[0], [], [1], [], [2], [], [3], [],
|
||||
[4], [], [5], [], [6], [], [7], []]
|
||||
self.assertDatasetProduces(rebatched_dataset, expected_output)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.times(test_base.default_test_combinations(),
|
||||
combinations.combine(drop_remainder=[True, False])))
|
||||
|
Loading…
x
Reference in New Issue
Block a user