Preserving infinite cardinality information for dataset.interleave
transformation.
PiperOrigin-RevId: 341064372 Change-Id: I628f362366d52286118753e7a64a334204daf739
This commit is contained in:
parent
3300d79e75
commit
5ab2ad101e
@ -219,6 +219,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
|||||||
ParallelInterleaveDatasetOp::kDatasetType, params);
|
ParallelInterleaveDatasetOp::kDatasetType, params);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
int64 Cardinality() const override {
|
||||||
|
int64 n = input_->Cardinality();
|
||||||
|
if (n == kInfiniteCardinality) {
|
||||||
|
return n;
|
||||||
|
}
|
||||||
|
return kUnknownCardinality;
|
||||||
|
}
|
||||||
|
|
||||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||||
inputs->push_back(input_);
|
inputs->push_back(input_);
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -94,6 +94,10 @@ def _test_combinations():
|
|||||||
lambda _: dataset_ops.Dataset.from_tensors(0),
|
lambda _: dataset_ops.Dataset.from_tensors(0),
|
||||||
cycle_length=1,
|
cycle_length=1,
|
||||||
num_parallel_calls=1), dataset_ops.UNKNOWN),
|
num_parallel_calls=1), dataset_ops.UNKNOWN),
|
||||||
|
("Interleave3", lambda: dataset_ops.Dataset.range(5).repeat().interleave(
|
||||||
|
lambda _: dataset_ops.Dataset.from_tensors(0),
|
||||||
|
cycle_length=1,
|
||||||
|
num_parallel_calls=1), dataset_ops.INFINITE),
|
||||||
("PaddedBatch1", lambda: dataset_ops.Dataset.range(5).padded_batch(
|
("PaddedBatch1", lambda: dataset_ops.Dataset.range(5).padded_batch(
|
||||||
2, [], drop_remainder=True), 2),
|
2, [], drop_remainder=True), 2),
|
||||||
("PaddedBatch2", lambda: dataset_ops.Dataset.range(5).padded_batch(
|
("PaddedBatch2", lambda: dataset_ops.Dataset.range(5).padded_batch(
|
||||||
|
Loading…
Reference in New Issue
Block a user