[tf.data] Addressing TODOs regarding deprecated APIs.

PiperOrigin-RevId: 299470038
Change-Id: If445f10866b9356e35abbd139a929d5a0f77c0bc
This commit is contained in:
Jiri Simsa 2020-03-06 17:18:48 -08:00 committed by TensorFlower Gardener
parent 4831d4f42f
commit 3874b28288
10 changed files with 75 additions and 69 deletions

View File

@ -619,12 +619,8 @@ class IteratorBase {
// Saves the state of this iterator.
//
// This method is used to store the state of the iterator in a checkpoint.
//
// TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
// implementations have an override.
virtual Status SaveInternal(IteratorStateWriter* writer) {
return errors::Unimplemented("SaveInternal");
}
virtual Status SaveInternal(IteratorStateWriter* writer) = 0;
// Restores the state of this iterator.
//
@ -633,13 +629,9 @@ class IteratorBase {
// Implementations may assume that the iterator is in a clean state. That is,
// its `Initialize` method has been called, but its `GetNext` method has
// never been called.
//
// TODO(jsimsa): Make this method pure virtual once all `IteratorBase`
// implementations have an override.
virtual Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) {
return errors::Unimplemented("RestoreInternal");
}
IteratorStateReader* reader) = 0;
// Returns the number of elements produced by this iterator.
int64 num_elements() const {
@ -749,22 +741,6 @@ class DatasetBase : public core::RefCounted {
return MakeIterator(&ctx, parent, output_prefix, iterator);
}
// TODO(jsimsa): Remove this overlead once all callers are migrated to the API
// that passes in the parent iterator pointer.
ABSL_DEPRECATED("Use the overload that passes the parent iterator pointer.")
Status MakeIterator(IteratorContext* ctx, const string& output_prefix,
std::unique_ptr<IteratorBase>* iterator) const {
return MakeIterator(ctx, /*parent=*/nullptr, output_prefix, iterator);
}
// TODO(jsimsa): Remove this overlead once all callers are migrated to the API
// that passes in the parent iterator pointer.
ABSL_DEPRECATED("Use the overload that passes the parent iterator pointer.")
Status MakeIterator(IteratorContext&& ctx, const string& output_prefix,
std::unique_ptr<IteratorBase>* iterator) const {
return MakeIterator(&ctx, output_prefix, iterator);
}
// Returns a new iterator restored from the checkpoint data in `reader`.
Status MakeIteratorFromCheckpoint(
IteratorContext* ctx, const string& output_prefix,
@ -807,27 +783,11 @@ class DatasetBase : public core::RefCounted {
// A human-readable debug string for this dataset.
virtual string DebugString() const = 0;
// If the dataset is stateful it will not be possible to save its graph or
// checkpoint the state of its iterators.
//
// TODO(jsimsa): Remove this method once all `DatasetBase` implementations are
// migrated over to `CheckExternalState`.
ABSL_DEPRECATED("Use CheckExternalState instead.")
virtual bool IsStateful() const { return false; }
// Indicates whether the dataset depends on any external state which would
// prevent it from being serializable. If so, the method returns
// `errors::FailedPrecondition` with a message that identifies the external
// state. Otherwise, the method returns `Status::OK()`.
//
// TODO(jsimsa): Make this method pure virtual once all `DatasetBase`
// implementations have an override.
virtual Status CheckExternalState() const {
if (IsStateful()) {
return errors::FailedPrecondition("Dataset cannot be serialized.");
}
return Status::OK();
}
virtual Status CheckExternalState() const = 0;
protected:
friend Status AsGraphDef(

View File

@ -432,15 +432,6 @@ Status MakeIteratorFromInputElement(
out_iterator);
}
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func,
StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator) {
return MakeIteratorFromInputElement(ctx, /*parent=*/nullptr, input_element,
thread_index, inst_captured_func, prefix,
out_iterator);
}
/* static */
Status FunctionMetadata::Create(
OpKernelConstruction* ctx, const string& func_name, Params params,

View File

@ -47,17 +47,6 @@ Status MakeIteratorFromInputElement(
const InstantiatedCapturedFunction& inst_captured_func, StringPiece prefix,
std::unique_ptr<IteratorBase>* out_iterator);
// Creates an iterator for a dataset which is created by applying the given
// function to the given input element.
//
// TODO(jsimsa): Remove this overload once all callers are migrated to the API
// that passes in the parent iterator pointer.
ABSL_DEPRECATED("Use the overload that passes the parent iterator pointer.")
Status MakeIteratorFromInputElement(
IteratorContext* ctx, const std::vector<Tensor>& input_element,
int64 thread_index, const InstantiatedCapturedFunction& inst_captured_func,
StringPiece prefix, std::unique_ptr<IteratorBase>* out_iterator);
// Determines whether the given node is stateful.
Status IsNodeStateful(const FunctionLibraryDefinition& library,
const NodeDef& node);

View File

@ -55,6 +55,8 @@ class WrapperDataset : public DatasetBase {
string DebugString() const override { return "WrapperDataset"; }
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -64,6 +64,8 @@ class RandomDatasetOp::Dataset : public DatasetBase {
int64 Cardinality() const override { return kInfiniteCardinality; }
Status CheckExternalState() const override { return Status::OK(); }
protected:
Status AsGraphDefInternal(SerializationContext* ctx,
DatasetGraphDefBuilder* b,

View File

@ -37,7 +37,7 @@ class RandomDatasetOp : public DatasetOpKernel {
explicit RandomDatasetOp(OpKernelConstruction* ctx);
protected:
virtual void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
void MakeDataset(OpKernelContext* ctx, DatasetBase** output) override;
private:
class Dataset;

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/core/kernels/data/dataset_utils.h"
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/util/work_sharder.h"
namespace tensorflow {
@ -203,6 +204,7 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
mutex_lock l(mu_);
return input_impl_->GetNext(IteratorContext(CreateParams(ctx)),
out_tensors, end_of_sequence);
}
@ -214,6 +216,20 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
/*ratio=*/1);
}
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
DCHECK(input_impl_ != nullptr);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
return Status::OK();
}
private:
IteratorContext::Params CreateParams(IteratorContext* ctx) {
ThreadPoolResource* pool = dataset()->threadpool_;
@ -225,7 +241,8 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
return params;
}
std::unique_ptr<IteratorBase> input_impl_;
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
};
const DatasetBase* const input_;
@ -319,6 +336,7 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
auto max_parallelism = dataset()->max_intra_op_parallelism_;
params.runner =
RunnerWithMaxParallelism(*ctx->runner(), max_parallelism);
mutex_lock l(mu_);
return input_impl_->GetNext(IteratorContext{std::move(params)},
out_tensors, end_of_sequence);
}
@ -330,8 +348,23 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
/*ratio=*/1);
}
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
DCHECK(input_impl_ != nullptr);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
return Status::OK();
}
private:
std::unique_ptr<IteratorBase> input_impl_;
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
};
const DatasetBase* const input_;
@ -425,6 +458,7 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
pool->Schedule(std::move(c));
};
params.runner_threadpool_size = dataset()->num_threads_;
mutex_lock l(mu_);
return input_impl_->GetNext(IteratorContext{std::move(params)},
out_tensors, end_of_sequence);
}
@ -436,8 +470,23 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
/*ratio=*/1);
}
Status SaveInternal(IteratorStateWriter* writer) override {
mutex_lock l(mu_);
DCHECK(input_impl_ != nullptr);
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
return Status::OK();
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impl_));
return Status::OK();
}
private:
std::unique_ptr<IteratorBase> input_impl_;
mutex mu_;
std::unique_ptr<IteratorBase> input_impl_ TF_GUARDED_BY(mu_);
};
const DatasetBase* const input_;

View File

@ -158,6 +158,17 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
return model::MakeSourceNode(std::move(args));
}
Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented(
"GeneratorDataset does not support checkpointing.");
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented(
"GeneratorDataset does not support checkpointing.");
}
private:
mutex mu_;
bool initialized_ TF_GUARDED_BY(mu_) = false;

View File

@ -573,7 +573,8 @@ class MultiDeviceIteratorInitOp : public OpKernel {
IteratorContext iter_ctx(std::move(params));
OP_REQUIRES_OK(
ctx, dataset->MakeIterator(std::move(iter_ctx), "Iterator", &iterator));
ctx, dataset->MakeIterator(std::move(iter_ctx), /*parent=*/nullptr,
"Iterator", &iterator));
int64 incarnation_id;
OP_REQUIRES_OK(ctx, resource->Init(std::move(iterator), max_buffer_size,
&incarnation_id));

View File

@ -496,7 +496,8 @@ TEST_P(ParameterizedIteratorSaveAndRestoreTest, IteratorSaveAndRestore) {
&window_dataset));
std::unique_ptr<IteratorBase> window_dataset_iterator;
TF_ASSERT_OK(window_dataset->MakeIterator(
iterator_ctx_.get(), test_case.dataset_params.iterator_prefix(),
iterator_ctx_.get(), /*parent=*/nullptr,
test_case.dataset_params.iterator_prefix(),
&window_dataset_iterator));
bool end_of_window_dataset = false;
std::vector<Tensor> window_elements;