[tf.data] Preserving accurate cardinality information for group_by_window transformation.

PiperOrigin-RevId: 338679574
Change-Id: Ibba771ef1050f5fbf08daf6dee251e4463f03e09
This commit is contained in:
Jiri Simsa 2020-10-23 08:39:30 -07:00 committed by TensorFlower Gardener
parent 7cb2593b93
commit 5479744e3f
3 changed files with 43 additions and 4 deletions

View File

@ -107,6 +107,14 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
return "GroupByWindowDatasetOp::Dataset";
}
// Reports the cardinality of this dataset based on its input: an infinite
// input yields an infinite result; any other input cardinality is reported
// as unknown.
int64 Cardinality() const override {
  const int64 input_cardinality = input_->Cardinality();
  // Only the infinite case is propagated; everything else collapses to
  // "unknown".
  if (input_cardinality != kInfiniteCardinality) {
    return kUnknownCardinality;
  }
  return input_cardinality;
}
Status InputDatasets(
std::vector<const DatasetBase*>* inputs) const override {
inputs->push_back(input_);

View File

@ -101,7 +101,8 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
for length, batch_size, bucket_elements in zip(lengths, batch_sizes,
n_bucket_elements):
# Calculate the expected sum across all batches of a specific sequence length.
# Calculate the expected sum across all batches of a specific sequence
# length.
expected_sums[length] = \
(bucket_elements - bucket_elements % batch_size) * length
# Calculate the expected occurrence of individual batch sizes.
@ -116,8 +117,8 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
# Produce 1 batch for each bucket
elements = []
for bucket_elements, length in zip(n_bucket_elements, lengths):
# Using only full sequences (opposed to the strategy employed in `testBucket`) makes
# checking the sum a lot easier.
# Using only full sequences (opposed to the strategy employed in
# `testBucket`) makes checking the sum a lot easier.
record_len = length
for _ in range(bucket_elements):
elements.append([1] * record_len)
@ -177,7 +178,8 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
generated_sums[length] += batch_sum
for l in lengths:
# Make sure the sum of the batch contents is correct for the individual sequence lengths.
# Make sure the sum of the batch contents is correct for the individual
# sequence lengths.
self.assertEqual(
generated_sums[l], expected_sums[l], "Tensor sums did not match! "
"expected: {}, generated: {}".format(expected_sums, generated_sums))
@ -261,6 +263,7 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
_test_bucket_by_padding(param_no_padding)
@combinations.generate(test_base.default_test_combinations())
def testPadToBoundary(self):
boundaries = [10, 20, 30]
@ -308,6 +311,7 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
sorted(lengths_val))
@combinations.generate(test_base.default_test_combinations())
def testPadToBoundaryNoExtraneousPadding(self):
boundaries = [3, 7, 11]
@ -460,6 +464,25 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
expected_batches = _compute_expected_batches(param_drop_remainder)
self.assertEqual(batches, expected_batches)
@combinations.generate(test_base.default_test_combinations())
def testCardinality(self):
  # Checks that bucketing an endlessly repeated input still reports an
  # infinite cardinality.
  bucket_boundaries = [3, 7, 11]
  bucket_batch_sizes = [2, 2, 2, 2]
  sequence_lengths = range(1, 11)

  def generate_sequences():
    # Yields one all-ones sequence per configured length, each wrapped in a
    # 1-tuple.
    for sequence_length in sequence_lengths:
      yield ([1] * sequence_length,)

  length_fn = lambda sequence: array_ops.shape(sequence)[0]

  repeated_input = dataset_ops.Dataset.from_generator(
      generate_sequences, (dtypes.int64,), ([None],)).repeat()
  bucketing = grouping.bucket_by_sequence_length(
      length_fn, bucket_boundaries, bucket_batch_sizes,
      pad_to_bucket_boundary=True)
  dataset = repeated_input.apply(bucketing)

  observed_cardinality = self.evaluate(dataset.cardinality())
  self.assertEqual(observed_cardinality, dataset_ops.INFINITE)
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
test.main()

View File

@ -341,6 +341,14 @@ class GroupByWindowTest(test_base.DatasetTestBase, parameterized.TestCase):
get_next = self.getNext(dataset)
self.evaluate(get_next())
@combinations.generate(test_base.default_test_combinations())
def testGroupByWindowCardinality(self):
  # Checks that grouping an endlessly repeated input by window still reports
  # an infinite cardinality.
  key_fn = lambda x: x % 2
  reduce_fn = lambda key, window: dataset_ops.Dataset.from_tensors(key)

  repeated_input = dataset_ops.Dataset.range(1).repeat()
  windowing = grouping.group_by_window(key_fn, reduce_fn, 4)
  dataset = repeated_input.apply(windowing)

  observed_cardinality = self.evaluate(dataset.cardinality())
  self.assertEqual(observed_cardinality, dataset_ops.INFINITE)
# Run the tests in this file when it is executed directly as a script.
if __name__ == "__main__":
test.main()