[tf.data] Completing migration to new internal APIs that make it possible to overriding policy for handling external state during iterator checkpointing.

PiperOrigin-RevId: 302531704
Change-Id: I6a235f9558c948d42acbb771e831ce9adb9d6c8e
This commit is contained in:
Jiri Simsa 2020-03-23 15:30:38 -07:00 committed by TensorFlower Gardener
parent c08ebde007
commit 8abb0d2992
3 changed files with 5 additions and 20 deletions

View File

@ -628,14 +628,6 @@ class IteratorBase {
return input->SaveInternal(ctx, writer);
}
// TODO(jsimsa): Remove this override when all callers are migrated to the
// override that uses SerializationContext.
Status SaveInput(IteratorStateWriter* writer,
const std::unique_ptr<IteratorBase>& input) {
SerializationContext ctx(/*params=*/{});
return input->SaveInternal(&ctx, writer);
}
// This is needed so that sub-classes of IteratorBase can call
// `RestoreInternal` on their input iterators.
Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
@ -648,16 +640,7 @@ class IteratorBase {
// This method is used to store the state of the iterator in a checkpoint.
// implementations have an override.
virtual Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) {
return SaveInternal(writer);
}
// TODO(jsimsa): Remove this override when all subclasses are migrated to the
// override that accepts SerializationContext and make that override pure
// virtual.
virtual Status SaveInternal(IteratorStateWriter* writer) {
return errors::Unimplemented("checkpointing is not supported");
}
IteratorStateWriter* writer) = 0;
// Restores the state of this iterator.
//

View File

@ -944,7 +944,8 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
/*ratio=*/1);
}
Status SaveInternal(IteratorStateWriter* writer) override {
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
return Status::OK();

View File

@ -192,7 +192,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel {
return model::MakeSourceNode(std::move(args));
}
Status SaveInternal(IteratorStateWriter* writer) override {
Status SaveInternal(SerializationContext* ctx,
IteratorStateWriter* writer) override {
mutex_lock l(mu_);
TF_RETURN_IF_ERROR(writer->WriteScalar(
full_name("current_pattern_index"), current_pattern_index_));