[tf.data] Completing migration to new internal APIs that make it possible to overriding policy for handling external state during iterator checkpointing.
PiperOrigin-RevId: 302531704 Change-Id: I6a235f9558c948d42acbb771e831ce9adb9d6c8e
This commit is contained in:
parent
c08ebde007
commit
8abb0d2992
@ -628,14 +628,6 @@ class IteratorBase {
|
|||||||
return input->SaveInternal(ctx, writer);
|
return input->SaveInternal(ctx, writer);
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO(jsimsa): Remove this override when all callers are migrated to the
|
|
||||||
// override that uses SerializationContext.
|
|
||||||
Status SaveInput(IteratorStateWriter* writer,
|
|
||||||
const std::unique_ptr<IteratorBase>& input) {
|
|
||||||
SerializationContext ctx(/*params=*/{});
|
|
||||||
return input->SaveInternal(&ctx, writer);
|
|
||||||
}
|
|
||||||
|
|
||||||
// This is needed so that sub-classes of IteratorBase can call
|
// This is needed so that sub-classes of IteratorBase can call
|
||||||
// `RestoreInternal` on their input iterators.
|
// `RestoreInternal` on their input iterators.
|
||||||
Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
|
Status RestoreInput(IteratorContext* ctx, IteratorStateReader* reader,
|
||||||
@ -648,16 +640,7 @@ class IteratorBase {
|
|||||||
// This method is used to store the state of the iterator in a checkpoint.
|
// This method is used to store the state of the iterator in a checkpoint.
|
||||||
// implementations have an override.
|
// implementations have an override.
|
||||||
virtual Status SaveInternal(SerializationContext* ctx,
|
virtual Status SaveInternal(SerializationContext* ctx,
|
||||||
IteratorStateWriter* writer) {
|
IteratorStateWriter* writer) = 0;
|
||||||
return SaveInternal(writer);
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO(jsimsa): Remove this override when all subclasses are migrated to the
|
|
||||||
// override that accepts SerializationContext and make that override pure
|
|
||||||
// virtual.
|
|
||||||
virtual Status SaveInternal(IteratorStateWriter* writer) {
|
|
||||||
return errors::Unimplemented("checkpointing is not supported");
|
|
||||||
}
|
|
||||||
|
|
||||||
// Restores the state of this iterator.
|
// Restores the state of this iterator.
|
||||||
//
|
//
|
||||||
|
@ -944,7 +944,8 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
|||||||
/*ratio=*/1);
|
/*ratio=*/1);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
Status SaveInternal(SerializationContext* ctx,
|
||||||
|
IteratorStateWriter* writer) override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
|
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), index_));
|
||||||
return Status::OK();
|
return Status::OK();
|
||||||
|
@ -192,7 +192,8 @@ class MatchingFilesDatasetOp : public DatasetOpKernel {
|
|||||||
return model::MakeSourceNode(std::move(args));
|
return model::MakeSourceNode(std::move(args));
|
||||||
}
|
}
|
||||||
|
|
||||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
Status SaveInternal(SerializationContext* ctx,
|
||||||
|
IteratorStateWriter* writer) override {
|
||||||
mutex_lock l(mu_);
|
mutex_lock l(mu_);
|
||||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||||
full_name("current_pattern_index"), current_pattern_index_));
|
full_name("current_pattern_index"), current_pattern_index_));
|
||||||
|
Loading…
Reference in New Issue
Block a user