[tf.data] Support asserting infinite cardinality.

Previously using tf.data.experimental.assert_cardinality(tf.data.experimental.INFINITE_CARDINALITY) would cause the assertion to fail as soon as the first dataset element was produced, even if the dataset actually was infinite. After this CL, we will only raise an error if the dataset runs out of elements.

Fixes https://github.com/tensorflow/tensorflow/issues/45894

PiperOrigin-RevId: 349321521
Change-Id: I54804225da55f49cef4fa69e498a239854d16e22
This commit is contained in:
Andrew Audibert 2020-12-28 13:03:29 -08:00 committed by TensorFlower Gardener
parent 7f58b07fb9
commit 59f5abfbc8
2 changed files with 16 additions and 1 deletions
tensorflow
core/kernels/data/experimental
python/data/experimental/kernel_tests

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <map>
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/kernels/data/name_utils.h"
@ -113,7 +114,8 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase {
ElementString(dataset()->cardinality_), " but contained only ",
ElementString(num_elements_), ".");
}
if (num_elements_ > dataset()->cardinality_) {
if (dataset()->cardinality_ != kInfiniteCardinality &&
num_elements_ > dataset()->cardinality_) {
return errors::FailedPrecondition(
"Input dataset was expected to contain ",
ElementString(dataset()->cardinality_), " but contained at least ",
@ -147,6 +149,9 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase {
private:
static string ElementString(int64 n) {
if (n == kInfiniteCardinality) {
return strings::StrCat("an infinite number of elements");
}
return strings::StrCat(n, " element", n != 1 ? "s" : "");
}

View File

@ -53,6 +53,16 @@ class AssertCardinalityTest(test_base.DatasetTestBase, parameterized.TestCase):
asserted_cardinality=20,
expected_error="Input dataset was expected to contain 20 "
"elements but contained only 1 element.") +
combinations.combine(
num_elements=10,
asserted_cardinality=cardinality.INFINITE,
expected_error="Input dataset was expected to contain an "
"infinite number of elements but contained only 10 elements.") +
combinations.combine(
num_elements=1,
asserted_cardinality=cardinality.INFINITE,
expected_error="Input dataset was expected to contain an "
"infinite number of elements but contained only 1 element.") +
combinations.combine(
num_elements=10,
asserted_cardinality=5,