Construct the correct shape for zero-batch size in RebatchDataset
PiperOrigin-RevId: 337441780
Change-Id: Ic70551b00882711120dc0f21f9d23fd1e9c9155c
parent f0844f4065
commit 3ba72a8d69
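The change below affects how RebatchDatasetV2Op fabricates outputs for splits whose desired batch size is zero: instead of pushing a bare Tensor(dtype) per component, the kernel now builds each empty tensor with a shape derived from the dataset's declared output_shapes(), with the batch dimension and any unknown dimensions set to zero, so zero-sized splits of higher-rank components come out with shapes that match the dataset signature. For orientation, a minimal sketch of the user-visible behavior (assumes TF eager mode and uses the private distribute._RebatchDataset API exactly as the tests added in this commit do):

    import tensorflow as tf
    from tensorflow.python.data.experimental.ops import distribute  # private API

    # Two global batches of 4, rebatched into per-split sizes [0, 1]: every
    # other output batch is empty.
    dataset = tf.data.Dataset.range(8).batch(4, drop_remainder=True)
    rebatched = distribute._RebatchDataset(
        dataset, batch_sizes=[0, 1], drop_remainder=False)

    for batch in rebatched:
      print(batch.shape, batch.numpy())  # (0,) [], then (1,) [0], and so on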
@@ -417,7 +417,6 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel {
         std::vector<Tensor> slices;
         slices.reserve(tensors_.size());
         for (const auto& tensor : tensors_) {
-          Tensor slice = tensor.Slice(offset_, slice_end);
           slices.push_back(tensor.Slice(offset_, slice_end));
         }
         slices_to_concatenate.push_back(std::move(slices));
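Context for the hunk above: the loop slices each component of the buffered input batch over [offset_, slice_end) and stashes the slices to be concatenated into the next output batch; the deleted Tensor slice temporary was computed but never used. A rough list-based analogue of that slice-and-collect step (illustration only, not the kernel's tensor code):

    def take_slice(input_batch, offset, desired_batch_size):
      """Grabs up to desired_batch_size elements of input_batch from offset.

      Returns the slice and the new offset; the kernel does the equivalent with
      Tensor::Slice per component and later concatenates slices collected from
      successive input batches into one output batch.
      """
      slice_end = min(offset + desired_batch_size, len(input_batch))
      return input_batch[offset:slice_end], slice_end

    slices_to_concatenate = []
    batch, offset = take_slice([10, 11, 12, 13], offset=1, desired_batch_size=2)
    slices_to_concatenate.append(batch)
    assert slices_to_concatenate == [[11, 12]] and offset == 3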
@@ -452,8 +451,28 @@ class RebatchDatasetV2Op : public UnaryDatasetOpKernel {
         if (desired_batch_size == 0) {
           DCHECK_EQ(batch_size, 0);
           DCHECK_EQ(slices_to_concatenate.size(), 0);
-          for (const auto& dtype : dataset()->output_dtypes()) {
-            out_tensors->push_back(Tensor(dtype));
+          for (int i = 0; i < dataset()->output_dtypes().size(); ++i) {
+            if (dataset()->output_shapes()[i].unknown_rank()) {
+              // For unknown rank tensors, we just create an empty Tensor since
+              // it doesn't matter what shape it is.
+              out_tensors->push_back(Tensor(dataset()->output_dtypes()[i]));
+            } else {
+              auto dim_sizes = dataset()->output_shapes()[i].dim_sizes();
+
+              // The output batch size is always zero since the desired batch
+              // size is zero.
+              dim_sizes[0] = 0;
+
+              // Handle unknown dimensions by setting any unknown dimensions to
+              // zero since there isn't any data anyway.
+              for (int j = 1; j < dim_sizes.size(); ++j) {
+                if (dim_sizes[j] == -1) dim_sizes[j] = 0;
+              }
+
+              TensorShape tensor_shape(dim_sizes);
+              out_tensors->push_back(
+                  Tensor(dataset()->output_dtypes()[i], tensor_shape));
+            }
           }
           return Status::OK();
         }
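The new branch above is the substance of the fix: rather than a default Tensor(dtype) per component, the kernel derives an all-zero-count shape from the declared output_shapes(). A small pure-Python mirror of that shape logic (no TensorFlow dependency; None stands in for unknown rank or unknown dimensions, analogous to PartialTensorShape):

    def zero_batch_shape(output_shape):
      """Mirrors the kernel's shape construction for a zero-sized output batch.

      output_shape: None for unknown rank, else a list with None for unknown
      dimensions. Returns None for unknown rank (the kernel just emits an empty
      Tensor), otherwise a fully defined shape whose batch dimension is 0.
      """
      if output_shape is None:   # Unknown rank: any shape is acceptable.
        return None
      dims = list(output_shape)
      dims[0] = 0                # The output batch size is always zero.
      # Remaining unknown dimensions become zero too; there is no data anyway.
      return [0 if d is None else d for d in dims]

    assert zero_batch_shape([None, 224, 224, 3]) == [0, 224, 224, 3]
    assert zero_batch_shape([None, None]) == [0, 0]
    assert zero_batch_shape(None) is None

Unknown rank falls back to an empty tensor because no concrete shape can be claimed; for known rank, the batch dimension and any remaining unknown dimensions are set to 0 since the tensor holds no data.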
@@ -287,6 +287,40 @@ class RebatchDatasetTest(test_base.DatasetTestBase, parameterized.TestCase):
     expected_output = [[0], [1], [2], [3], [], [4], [5], [6], [7], []]
     self.assertDatasetProduces(rebatched_dataset, expected_output)
 
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(drop_remainder=[True, False])))
+  def testEmptyFirstSplits(self, drop_remainder):
+    dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
+    rebatched_dataset = distribute._RebatchDataset(
+        dataset, batch_sizes=[0, 1], drop_remainder=drop_remainder)
+
+    expected_shapes = [[None]]
+    self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
+
+    # We have an extra element at the end because if the desired batch size is
+    # zero, then we never read any inputs from the input_dataset at all, so we
+    # will keep producing empty outputs until we reach a non-zero desired batch
+    # size split.
+    expected_output = [[], [0], [], [1], [], [2], [], [3],
+                       [], [4], [], [5], [], [6], [], [7], []]
+    self.assertDatasetProduces(rebatched_dataset, expected_output)
+
+  @combinations.generate(
+      combinations.times(test_base.default_test_combinations(),
+                         combinations.combine(drop_remainder=[True, False])))
+  def testEmptyLastSplits(self, drop_remainder):
+    dataset = dataset_ops.Dataset.range(8).batch(4, drop_remainder=True)
+    rebatched_dataset = distribute._RebatchDataset(
+        dataset, batch_sizes=[1, 0], drop_remainder=drop_remainder)
+
+    expected_shapes = [[None]]
+    self.assertEqual(expected_shapes, _flat_shapes(rebatched_dataset))
+
+    expected_output = [[0], [], [1], [], [2], [], [3], [],
+                       [4], [], [5], [], [6], [], [7], []]
+    self.assertDatasetProduces(rebatched_dataset, expected_output)
+
   @combinations.generate(
       combinations.times(test_base.default_test_combinations(),
                          combinations.combine(drop_remainder=[True, False])))
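The expected outputs in testEmptyFirstSplits and testEmptyLastSplits follow from cycling through batch_sizes over the flattened input, including the trailing empty batch noted in the test comment: a zero-sized split never reads from the input, so one more empty output is produced before exhaustion is detected. A throwaway pure-Python generator, written only to reproduce the expected sequences (not the op's implementation):

    import itertools

    def rebatch(elements, batch_sizes):
      """Reproduces the expected outputs in the tests above (illustration only)."""
      it = iter(elements)
      sizes = itertools.cycle(batch_sizes)
      while True:
        size = next(sizes)
        if size == 0:
          # A zero-sized split never pulls from the input, so an empty batch is
          # emitted even when the input is already exhausted.
          yield []
          continue
        batch = list(itertools.islice(it, size))
        if not batch:
          return
        yield batch

    assert list(rebatch(range(8), [0, 1])) == [
        [], [0], [], [1], [], [2], [], [3],
        [], [4], [], [5], [], [6], [], [7], []]
    assert list(rebatch(range(8), [1, 0])) == [
        [0], [], [1], [], [2], [], [3], [],
        [4], [], [5], [], [6], [], [7], []]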