Preserving infinite cardinality information for dataset.interleave transformation.

PiperOrigin-RevId: 341064372
Change-Id: I628f362366d52286118753e7a64a334204daf739
This commit is contained in:
Ruoxin Sang 2020-11-06 09:37:25 -08:00 committed by TensorFlower Gardener
parent 3300d79e75
commit 5ab2ad101e
2 changed files with 12 additions and 0 deletions

View File

@ -219,6 +219,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
ParallelInterleaveDatasetOp::kDatasetType, params);
}
int64 Cardinality() const override {
int64 n = input_->Cardinality();
if (n == kInfiniteCardinality) {
return n;
}
return kUnknownCardinality;
}
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
inputs->push_back(input_);
return Status::OK();

View File

@ -94,6 +94,10 @@ def _test_combinations():
lambda _: dataset_ops.Dataset.from_tensors(0),
cycle_length=1,
num_parallel_calls=1), dataset_ops.UNKNOWN),
("Interleave3", lambda: dataset_ops.Dataset.range(5).repeat().interleave(
lambda _: dataset_ops.Dataset.from_tensors(0),
cycle_length=1,
num_parallel_calls=1), dataset_ops.INFINITE),
("PaddedBatch1", lambda: dataset_ops.Dataset.range(5).padded_batch(
2, [], drop_remainder=True), 2),
("PaddedBatch2", lambda: dataset_ops.Dataset.range(5).padded_batch(