[tf.data] Preserving accurate cardinality information for group_by_window
transformation.
PiperOrigin-RevId: 338679574 Change-Id: Ibba771ef1050f5fbf08daf6dee251e4463f03e09
This commit is contained in:
parent
7cb2593b93
commit
5479744e3f
@ -107,6 +107,14 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
return "GroupByWindowDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
int64 Cardinality() const override {
|
||||
int64 n = input_->Cardinality();
|
||||
if (n == kInfiniteCardinality) {
|
||||
return n;
|
||||
}
|
||||
return kUnknownCardinality;
|
||||
}
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
|
@ -101,7 +101,8 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
|
||||
|
||||
for length, batch_size, bucket_elements in zip(lengths, batch_sizes,
|
||||
n_bucket_elements):
|
||||
# Calculate the expected sum across all batches of a specific sequence length.
|
||||
# Calculate the expected sum across all batches of a specific sequence
|
||||
# length.
|
||||
expected_sums[length] = \
|
||||
(bucket_elements - bucket_elements % batch_size) * length
|
||||
# Calculate the expected occurrence of individual batch sizes.
|
||||
@ -116,8 +117,8 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
|
||||
# Produce 1 batch for each bucket
|
||||
elements = []
|
||||
for bucket_elements, length in zip(n_bucket_elements, lengths):
|
||||
# Using only full sequences (opposed to the strategy employed in `testBucket`) makes
|
||||
# checking the sum a lot easier.
|
||||
# Using only full sequences (opposed to the strategy employed in
|
||||
# `testBucket`) makes checking the sum a lot easier.
|
||||
record_len = length
|
||||
for _ in range(bucket_elements):
|
||||
elements.append([1] * record_len)
|
||||
@ -177,7 +178,8 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
|
||||
generated_sums[length] += batch_sum
|
||||
|
||||
for l in lengths:
|
||||
# Make sure the sum of the batch contents is correct for the individual sequence lengths.
|
||||
# Make sure the sum of the batch contents is correct for the individual
|
||||
# sequence lengths.
|
||||
self.assertEqual(
|
||||
generated_sums[l], expected_sums[l], "Tensor sums did not match! "
|
||||
"expected: {}, generated: {}".format(expected_sums, generated_sums))
|
||||
@ -261,6 +263,7 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
|
||||
|
||||
_test_bucket_by_padding(param_no_padding)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPadToBoundary(self):
|
||||
|
||||
boundaries = [10, 20, 30]
|
||||
@ -308,6 +311,7 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
|
||||
self.assertEqual([boundary - 1 for boundary in sorted(boundaries)],
|
||||
sorted(lengths_val))
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testPadToBoundaryNoExtraneousPadding(self):
|
||||
|
||||
boundaries = [3, 7, 11]
|
||||
@ -460,6 +464,25 @@ class BucketBySequenceLengthTest(test_base.DatasetTestBase,
|
||||
expected_batches = _compute_expected_batches(param_drop_remainder)
|
||||
self.assertEqual(batches, expected_batches)
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testCardinality(self):
|
||||
|
||||
boundaries = [3, 7, 11]
|
||||
batch_sizes = [2, 2, 2, 2]
|
||||
lengths = range(1, 11)
|
||||
|
||||
def element_gen():
|
||||
for length in lengths:
|
||||
yield ([1] * length,)
|
||||
|
||||
element_len = lambda element: array_ops.shape(element)[0]
|
||||
dataset = dataset_ops.Dataset.from_generator(
|
||||
element_gen, (dtypes.int64,), ([None],)).repeat().apply(
|
||||
grouping.bucket_by_sequence_length(
|
||||
element_len, boundaries, batch_sizes,
|
||||
pad_to_bucket_boundary=True))
|
||||
self.assertEqual(self.evaluate(dataset.cardinality()), dataset_ops.INFINITE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -341,6 +341,14 @@ class GroupByWindowTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
get_next = self.getNext(dataset)
|
||||
self.evaluate(get_next())
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def testGroupByWindowCardinality(self):
|
||||
dataset = dataset_ops.Dataset.range(1).repeat().apply(
|
||||
grouping.group_by_window(
|
||||
lambda x: x % 2,
|
||||
lambda key, window: dataset_ops.Dataset.from_tensors(key), 4))
|
||||
self.assertEqual(self.evaluate(dataset.cardinality()), dataset_ops.INFINITE)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user