diff --git a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc index 30d0f9405f7..69e2cbb318e 100644 --- a/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/assert_cardinality_dataset_op.cc @@ -16,6 +16,7 @@ limitations under the License. #include +#include "tensorflow/core/framework/dataset.h" #include "tensorflow/core/framework/partial_tensor_shape.h" #include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/kernels/data/name_utils.h" @@ -113,7 +114,8 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { ElementString(dataset()->cardinality_), " but contained only ", ElementString(num_elements_), "."); } - if (num_elements_ > dataset()->cardinality_) { + if (dataset()->cardinality_ != kInfiniteCardinality && + num_elements_ > dataset()->cardinality_) { return errors::FailedPrecondition( "Input dataset was expected to contain ", ElementString(dataset()->cardinality_), " but contained at least ", @@ -147,6 +149,9 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase { private: static string ElementString(int64 n) { + if (n == kInfiniteCardinality) { + return strings::StrCat("an infinite number of elements"); + } return strings::StrCat(n, " element", n != 1 ? "s" : ""); } diff --git a/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py b/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py index 362495744dc..9fa88adff39 100644 --- a/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/assert_cardinality_test.py @@ -53,6 +53,16 @@ class AssertCardinalityTest(test_base.DatasetTestBase, parameterized.TestCase): asserted_cardinality=20, expected_error="Input dataset was expected to contain 20 " "elements but contained only 1 element.") + + combinations.combine( + num_elements=10, + asserted_cardinality=cardinality.INFINITE, + expected_error="Input dataset was expected to contain an " + "infinite number of elements but contained only 10 elements.") + + combinations.combine( + num_elements=1, + asserted_cardinality=cardinality.INFINITE, + expected_error="Input dataset was expected to contain an " + "infinite number of elements but contained only 1 element.") + combinations.combine( num_elements=10, asserted_cardinality=5,