[tf.data] Implement DatasetBase::InputDatasets() for core datasets.
PiperOrigin-RevId: 332063322 Change-Id: Ib776466a1f2844b98e8257b28eadb26af9abb9ed
This commit is contained in:
parent
9fd447d0be
commit
a58bfeef15
tensorflow/core/kernels/data
batch_dataset_op.cccache_dataset_ops.ccconcatenate_dataset_op.cc
experimental
assert_cardinality_dataset_op.ccassert_next_dataset_op.ccchoose_fastest_branch_dataset_op.ccchoose_fastest_dataset_op.ccdense_to_sparse_batch_dataset_op.ccgroup_by_window_dataset_op.ccignore_errors_dataset_op.cclmdb_dataset_op.ccmap_and_batch_dataset_op.ccmatching_files_dataset_op.ccnon_serializable_dataset_op.ccparallel_interleave_dataset_op.ccparse_example_dataset_op.ccrandom_dataset_op.ccsampling_dataset_op.ccscan_dataset_op.ccset_stats_aggregator_dataset_op.ccsleep_dataset_op.ccsliding_window_dataset_op.ccsnapshot_dataset_op.ccsnapshot_util.ccsql_dataset_op.ccstats_dataset_ops.cctake_while_dataset_op.ccthreadpool_dataset_op.ccunbatch_dataset_op.ccunique_dataset_op.cc
filter_dataset_op.ccfixed_length_record_dataset_op.ccflat_map_dataset_op.ccgenerator_dataset_op.ccinterleave_dataset_op.ccmap_dataset_op.ccpadded_batch_dataset_op.ccparallel_interleave_dataset_op.ccparallel_map_dataset_op.ccprefetch_dataset_op.ccrepeat_dataset_op.ccshard_dataset_op.ccshuffle_dataset_op.ccskip_dataset_op.ccsparse_tensor_slice_dataset_op.cctake_dataset_op.cctake_dataset_op.htensor_dataset_op.cctensor_slice_dataset_op.cctext_line_dataset_op.cctf_record_dataset_op.ccwindow_dataset.ccwindow_dataset_op.cczip_dataset_op.cc@ -117,6 +117,11 @@ class BatchDatasetOp::Dataset : public DatasetBase {
|
||||
return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -102,6 +102,11 @@ class CacheDatasetOp::FileDatasetBase : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
@ -680,6 +685,11 @@ class CacheDatasetOp::MemoryDatasetBase : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -85,6 +85,12 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase {
|
||||
return n1 + n2;
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
inputs->push_back(to_concatenate_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(input_->CheckExternalState());
|
||||
return to_concatenate_->CheckExternalState();
|
||||
|
@ -67,6 +67,11 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return cardinality_; }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -64,6 +64,11 @@ class AssertNextDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -55,6 +55,10 @@ class WrapperDataset : public DatasetBase {
|
||||
|
||||
string DebugString() const override { return "WrapperDataset"; }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
@ -245,6 +249,12 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
|
||||
return static_cast<double>(n) * ratio_numerator_ / ratio_denominator_;
|
||||
}
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
for (const auto& captured_func : captured_funcs_) {
|
||||
TF_RETURN_IF_ERROR(captured_func->CheckExternalState());
|
||||
|
@ -158,6 +158,14 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return cardinality_; }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
for (const auto& input : inputs_) {
|
||||
inputs->push_back(input);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
for (const auto& input : inputs_) {
|
||||
TF_RETURN_IF_ERROR(input->CheckExternalState());
|
||||
|
@ -120,6 +120,12 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
return n / batch_size_ + (n % batch_size_ == 0 ? 0 : 1);
|
||||
}
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -107,6 +107,12 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
return "GroupByWindowDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_key_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(captured_reduce_func_->CheckExternalState());
|
||||
|
@ -66,6 +66,12 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -54,6 +54,10 @@ class LMDBDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
string DebugString() const override { return "LMDBDatasetOp::Dataset"; }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -137,6 +137,11 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
|
@ -82,6 +82,11 @@ class MatchingFilesDatasetOp : public DatasetOpKernel {
|
||||
return "MatchingFilesDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -69,6 +69,12 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel {
|
||||
return "NonSerializableDatasetOp::Dataset";
|
||||
}
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -147,6 +147,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType, params);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
|
@ -270,6 +270,12 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -64,6 +64,10 @@ class RandomDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return kInfiniteCardinality; }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -70,6 +70,11 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -114,6 +114,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
}
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
|
@ -138,6 +138,12 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -68,6 +68,12 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -108,6 +108,12 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
return n / window_shift_;
|
||||
}
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -122,6 +122,8 @@ class SnapshotDatasetV2Op::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override;
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override;
|
||||
|
||||
Status CheckExternalState() const override;
|
||||
|
||||
protected:
|
||||
@ -327,6 +329,12 @@ int64 SnapshotDatasetV2Op::Dataset::Cardinality() const {
|
||||
return input_->Cardinality();
|
||||
}
|
||||
|
||||
Status SnapshotDatasetV2Op::Dataset::InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SnapshotDatasetV2Op::Dataset::CheckExternalState() const {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
@ -1034,6 +1042,12 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -363,6 +363,10 @@ class Reader::Dataset : public DatasetBase {
|
||||
return "snapshot_util::Reader::Dataset";
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
@ -483,6 +487,11 @@ class Reader::NestedDataset : public DatasetBase {
|
||||
return "snapshot_util::Reader::NestedDataset";
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->clear();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -103,6 +103,11 @@ class SqlDatasetOp : public DatasetOpKernel {
|
||||
|
||||
string DebugString() const override { return "SqlDatasetOp::Dataset"; }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -78,6 +78,12 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -82,6 +82,12 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return kUnknownCardinality; }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
|
@ -173,6 +173,12 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
@ -301,6 +307,13 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->clear();
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
@ -423,6 +436,13 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->clear();
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -84,6 +84,12 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
return kUnknownCardinality;
|
||||
}
|
||||
|
||||
Status InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -54,6 +54,11 @@ class UniqueDatasetOp::Dataset : public DatasetBase {
|
||||
return strings::StrCat("UniqueDatasetOp::Dataset");
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -75,6 +75,11 @@ class FilterDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
|
@ -93,6 +93,10 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType, params);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -75,6 +75,11 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
|
@ -74,6 +74,10 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(init_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(next_func_->CheckExternalState());
|
||||
|
@ -87,6 +87,11 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
|
@ -80,6 +80,11 @@ class MapDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
|
@ -119,6 +119,11 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase {
|
||||
return n / batch_size_ + (n % batch_size_ == 0 || drop_remainder_ ? 0 : 1);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -219,6 +219,11 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
ParallelInterleaveDatasetOp::kDatasetType, params);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
|
@ -120,6 +120,11 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
TF_RETURN_IF_ERROR(captured_func_->CheckExternalState());
|
||||
return input_->CheckExternalState();
|
||||
|
@ -88,6 +88,11 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return input_->Cardinality(); }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -89,6 +89,11 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
|
||||
return count_ * n;
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
@ -158,6 +163,9 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}
|
||||
++i_;
|
||||
if (ctx->split_provider()) {
|
||||
TF_RETURN_IF_ERROR(ctx->split_provider()->Reset());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->input_->MakeIterator(ctx, this, prefix(), &input_impl_));
|
||||
}
|
||||
|
@ -84,6 +84,11 @@ class ShardDatasetOp::Dataset : public DatasetBase {
|
||||
return n / num_shards_ + (index_ < n % num_shards_ ? 1 : 0);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -117,6 +117,11 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
||||
}
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
@ -182,7 +187,8 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
||||
data_produced_ = true;
|
||||
break;
|
||||
}
|
||||
if (!data_produced_ && this->dataset()->count_ == -1) {
|
||||
if (ctx->split_provider() == nullptr && !data_produced_ &&
|
||||
this->dataset()->count_ == -1) {
|
||||
// If we encounter the end of sequence without producing data, we
|
||||
// terminate the iteration immediately. (Otherwise, this iterator
|
||||
// would loop infinitely and never produce a value.)
|
||||
@ -192,6 +198,9 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
||||
epoch_++;
|
||||
int64 n = slices_.back()->end;
|
||||
slices_.push_back(absl::make_unique<Slice>(n, n));
|
||||
if (ctx->split_provider()) {
|
||||
TF_RETURN_IF_ERROR(ctx->split_provider()->Reset());
|
||||
}
|
||||
TF_RETURN_IF_ERROR(this->dataset()->input_->MakeIterator(
|
||||
ctx, this, this->prefix(), &input_impl_));
|
||||
}
|
||||
|
@ -75,6 +75,11 @@ class SkipDatasetOp::Dataset : public DatasetBase {
|
||||
return count_ < 0 ? 0 : std::max(int64{0}, n - count_);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -56,6 +56,10 @@ class Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return sparse_tensor_.shape()[0]; }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -74,6 +74,12 @@ int64 TakeDataset::Cardinality() const {
|
||||
return std::min(n, count_);
|
||||
}
|
||||
|
||||
Status TakeDataset::InputDatasets(
|
||||
std::vector<const DatasetBase*>* inputs) const {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status TakeDataset::CheckExternalState() const {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -40,6 +40,8 @@ class TakeDataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override;
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override;
|
||||
|
||||
Status CheckExternalState() const override;
|
||||
|
||||
protected:
|
||||
|
@ -64,6 +64,10 @@ class TensorDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return 1LL; }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -69,6 +69,10 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
int64 Cardinality() const override { return tensors_[0].dim_size(0); }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -70,6 +70,10 @@ class TextLineDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -85,6 +85,10 @@ class TFRecordDatasetOp::Dataset : public DatasetBase {
|
||||
return name_utils::DatasetDebugString(kDatasetType);
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -67,6 +67,10 @@ class WindowDataset : public DatasetBase {
|
||||
|
||||
string DebugString() const override { return kWindowDataset; }
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override { return Status::OK(); }
|
||||
|
||||
protected:
|
||||
|
@ -105,6 +105,11 @@ class WindowDatasetOp::Dataset : public DatasetBase {
|
||||
return cardinality;
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
inputs->push_back(input_);
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
return input_->CheckExternalState();
|
||||
}
|
||||
|
@ -87,6 +87,13 @@ class ZipDatasetOp::Dataset : public DatasetBase {
|
||||
return result;
|
||||
}
|
||||
|
||||
Status InputDatasets(std::vector<const DatasetBase*>* inputs) const override {
|
||||
for (const auto& input : inputs_) {
|
||||
inputs->push_back(input);
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status CheckExternalState() const override {
|
||||
for (const auto& input : inputs_) {
|
||||
TF_RETURN_IF_ERROR(input->CheckExternalState());
|
||||
|
Loading…
Reference in New Issue
Block a user