Preserving infinite cardinality information for dataset.interleave
transformation.
PiperOrigin-RevId: 341064372 Change-Id: I628f362366d52286118753e7a64a334204daf739
This commit is contained in:
parent
3300d79e75
commit
5ab2ad101e
@ -219,6 +219,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
ParallelInterleaveDatasetOp::kDatasetType, params);
|
||||
}
|
||||
|
||||
int64 Cardinality() const override {
|
||||
int64 n = input_->Cardinality();
|
||||
if (n == kInfiniteCardinality) {
|
||||
return n;
|
||||
}
|
||||
return kUnknownCardinality;
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
|
@ -94,6 +94,10 @@ def _test_combinations():
|
||||
lambda _: dataset_ops.Dataset.from_tensors(0),
|
||||
cycle_length=1,
|
||||
num_parallel_calls=1), dataset_ops.UNKNOWN),
|
||||
("Interleave3", lambda: dataset_ops.Dataset.range(5).repeat().interleave(
|
||||
lambda _: dataset_ops.Dataset.from_tensors(0),
|
||||
cycle_length=1,
|
||||
num_parallel_calls=1), dataset_ops.INFINITE),
|
||||
("PaddedBatch1", lambda: dataset_ops.Dataset.range(5).padded_batch(
|
||||
2, [], drop_remainder=True), 2),
|
||||
("PaddedBatch2", lambda: dataset_ops.Dataset.range(5).padded_batch(
|
||||
|
Loading…
Reference in New Issue
Block a user