From 59f5abfbc8dc5559c361f80f4fa4a006db825e40 Mon Sep 17 00:00:00 2001 From: Andrew Audibert Date: Mon, 28 Dec 2020 13:03:29 -0800 Subject: [PATCH] [tf.data] Support asserting infinite cardinality. Previously using tf.data.experimental.assert_cardinality(tf.data.experimental.INFINITE_CARDINALITY) would cause the assertion to fail as soon as the first dataset element was produced, even if the dataset actually was infinite. After this CL, we will only raise an error if the dataset runs out of elements. Fixes https://github.com/tensorflow/tensorflow/issues/45894 PiperOrigin-RevId: 349321521 Change-Id: I54804225da55f49cef4fa69e498a239854d16e22 --- .../data/experimental/assert_cardinality_dataset_op.cc | 7 ++++++- .../kernel_tests/assert_cardinality_test.py | 10 ++++++++++ 2 files changed, 16 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc index 30d0f9405f7..69e2cbb318e 100644 --- a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/name_utils.h" @@ -113,7 +114,8 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { ElementString(dataset()->cardinality_), " but contained only ", ElementString(num_elements_), "."); } - if (num_elements_ > dataset()->cardinality_) { + if (dataset()->cardinality_ != kInfiniteCardinality && + num_elements_ > dataset()->cardinality_) { return errors::FailedPrecondition( "Input dataset was expected to contain ", ElementString(dataset()->cardinality_), " but contained at least ", @@ -147,6 +149,9 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { private: static string ElementString(int64 n) { + if (n == kInfiniteCardinality) { + return strings::StrCat("an infinite number of elements"); + } return strings::StrCat(n, " element", n != 1 ? "s" : ""); } diff --git a/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py b/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py index 362495744dc..9fa88adff39 100644 --- a/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py @@ -53,6 +53,16 @@ class AssertCardinalityTest(test_base.DatasetTestBase, parameterized.TestCase): asserted_cardinality=20, expected_error="Input dataset was expected to contain 20 " "elements but contained only 1 element.") + + combinations.combine( + num_elements=10, + asserted_cardinality=cardinality.INFINITE, + expected_error="Input dataset was expected to contain an " + "infinite number of elements but contained only 10 elements.") + + combinations.combine( + num_elements=1, + asserted_cardinality=cardinality.INFINITE, + expected_error="Input dataset was expected to contain an " + "infinite number of elements but contained only 1 element.") + combinations.combine( num_elements=10, asserted_cardinality=5,