diff --git a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc index 5b03174f77c..8acc2f69e1d 100644 --- a/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc +++ b/tensorflow/core/kernels/data/parallel_interleave_dataset_op.cc @@ -219,6 +219,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase { ParallelInterleaveDatasetOp::kDatasetType, params); } + int64 Cardinality() const override { + int64 n = input_->Cardinality(); + if (n == kInfiniteCardinality) { + return n; + } + return kUnknownCardinality; + } + Status InputDatasets(std::vector* inputs) const override { inputs->push_back(input_); return Status::OK(); diff --git a/tensorflow/python/data/kernel_tests/cardinality_test.py b/tensorflow/python/data/kernel_tests/cardinality_test.py index d46b3c47369..cc29893eb90 100644 --- a/tensorflow/python/data/kernel_tests/cardinality_test.py +++ b/tensorflow/python/data/kernel_tests/cardinality_test.py @@ -94,6 +94,10 @@ def _test_combinations(): lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1, num_parallel_calls=1), dataset_ops.UNKNOWN), + ("Interleave3", lambda: dataset_ops.Dataset.range(5).repeat().interleave( + lambda _: dataset_ops.Dataset.from_tensors(0), + cycle_length=1, + num_parallel_calls=1), dataset_ops.INFINITE), ("PaddedBatch1", lambda: dataset_ops.Dataset.range(5).padded_batch( 2, [], drop_remainder=True), 2), ("PaddedBatch2", lambda: dataset_ops.Dataset.range(5).padded_batch(