[tf.data] Adding support for overriding external state policy for checkpointing.
PiperOrigin-RevId: 301443563 Change-Id: I852269b86039a71466ddeadfe3ce03d75dc45fda
This commit is contained in:
parent
83c0affbbd
commit
3a15248de3
@ -481,6 +481,25 @@ class SerializationContext {
|
||||
kFail = 2,
|
||||
};
|
||||
|
||||
// Handles the CheckExternalState status according to the external state
|
||||
// policy.
|
||||
Status HandleCheckExternalStateStatus(Status s) {
|
||||
if (s.ok()) {
|
||||
return s;
|
||||
}
|
||||
switch (params_.external_state_policy) {
|
||||
case ExternalStatePolicy::kWarn:
|
||||
LOG(WARNING) << s.ToString();
|
||||
return Status::OK();
|
||||
case ExternalStatePolicy::kIgnore:
|
||||
VLOG(2) << "Ignoring error status: " << s.ToString();
|
||||
return Status::OK();
|
||||
case ExternalStatePolicy::kFail:
|
||||
return s;
|
||||
}
|
||||
LOG(FATAL) << "Control should never reach here";
|
||||
}
|
||||
|
||||
struct Params {
|
||||
std::vector<std::pair<string, Tensor>>* input_list = nullptr; // Not owned.
|
||||
|
||||
@ -589,7 +608,7 @@ class IteratorBase {
|
||||
|
||||
// Saves the state of this iterator.
|
||||
virtual Status Save(SerializationContext* ctx, IteratorStateWriter* writer) {
|
||||
return SaveInternal(writer);
|
||||
return SaveInternal(ctx, writer);
|
||||
}
|
||||
|
||||
protected:
|
||||
@ -604,9 +623,17 @@ class IteratorBase {
|
||||
|
||||
// This is needed so that sub-classes of IteratorBase can call
|
||||
// `SaveInternal` on their input iterators.
|
||||
Status SaveInput(SerializationContext* ctx, IteratorStateWriter* writer,
|
||||
const std::unique_ptr<IteratorBase>& input) {
|
||||
return input->SaveInternal(ctx, writer);
|
||||
}
|
||||
|
||||
// TODO(jsimsa): Remove this override when all callers are migrated to the
|
||||
// override that uses SerializationContext.
|
||||
Status SaveInput(IteratorStateWriter* writer,
|
||||
const std::unique_ptr<IteratorBase>& input) {
|
||||
return input->SaveInternal(writer);
|
||||
SerializationContext ctx(/*params=*/{});
|
||||
return input->SaveInternal(&ctx, writer);
|
||||
}
|
||||
|
||||
// This is needed so that sub-classes of IteratorBase can call
|
||||
@ -620,7 +647,17 @@ class IteratorBase {
|
||||
//
|
||||
// This method is used to store the state of the iterator in a checkpoint.
|
||||
// implementations have an override.
|
||||
virtual Status SaveInternal(IteratorStateWriter* writer) = 0;
|
||||
virtual Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) {
|
||||
return SaveInternal(writer);
|
||||
}
|
||||
|
||||
// TODO(jsimsa): Remove this override when all subclasses are migrated to the
|
||||
// override that accepts SerializationContext and make that override pure
|
||||
// virtual.
|
||||
virtual Status SaveInternal(IteratorStateWriter* writer) {
|
||||
return errors::Unimplemented("checkpointing is not supported");
|
||||
}
|
||||
|
||||
// Restores the state of this iterator.
|
||||
//
|
||||
|
@ -257,12 +257,13 @@ class BatchDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!input_impl_) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -164,10 +164,11 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kMode), mode_));
|
||||
return SaveInput(writer, iterator_);
|
||||
return SaveInput(ctx, writer, iterator_);
|
||||
}
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
@ -303,7 +304,8 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kCurIndex), cur_index_));
|
||||
@ -333,7 +335,7 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
||||
lockfile_ = strings::StrCat(filename_, kLockFileSuffix);
|
||||
lockfile_created_ = false;
|
||||
}
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kShardId), shard_id_));
|
||||
return Status::OK();
|
||||
}
|
||||
@ -532,7 +534,8 @@ class CacheDatasetOp::FileDataset : public DatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kCurIndex), cur_index_));
|
||||
@ -785,14 +788,15 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (cache_->IsCompleted()) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCacheCompleted), ""));
|
||||
TF_RETURN_IF_ERROR(SaveCache(
|
||||
writer, cache_, [this](const string& s) { return full_name(s); }));
|
||||
}
|
||||
return SaveInput(writer, iterator_);
|
||||
return SaveInput(ctx, writer, iterator_);
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
@ -867,14 +871,15 @@ class CacheDatasetOp::MemoryDataset : public DatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!cache_->IsCompleted()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
SaveCache(writer, &temp_cache_,
|
||||
[this](const string& s) { return full_name(s); }));
|
||||
}
|
||||
return SaveInput(writer, input_impl_);
|
||||
return SaveInput(ctx, writer, input_impl_);
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
|
@ -146,11 +146,12 @@ class ConcatenateDatasetOp::Dataset : public DatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kIndex), i_));
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kInputImplUninitialized), ""));
|
||||
|
@ -124,10 +124,11 @@ class AssertCardinalityDatasetOp::Dataset : public DatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("num_elements"), num_elements_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -119,8 +119,9 @@ class AssertNextDatasetOp::Dataset : public DatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -105,7 +105,8 @@ class WrapperDataset : public DatasetBase {
|
||||
return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1.0);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -393,9 +394,10 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
|
||||
// TODO(rachelim): Save and restore histogram state as well. Currently,
|
||||
// if an iterator is saved and restored, the histograms start recording
|
||||
// from scratch.
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"),
|
||||
experiment_counter_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -403,7 +405,7 @@ class ChooseFastestBranchDatasetOp : public UnaryDatasetOpKernel {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("fastest_index"), fastest_index_));
|
||||
if (current_iterator_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, current_iterator_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_iterator_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
|
@ -238,7 +238,8 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
|
||||
// TODO(rachelim): Save and restore histogram state as well. Currently,
|
||||
// if an iterator is saved and restored, the histograms start recording
|
||||
// from scratch.
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("experiment_counter"),
|
||||
experiment_counter_));
|
||||
@ -246,13 +247,13 @@ class ChooseFastestDatasetOp : public DatasetOpKernel {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("fastest_index"), fastest_index_));
|
||||
if (fastest_index_ != -1) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, fastest_input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, fastest_input_impl_));
|
||||
} else if (input_impls_.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impls_empty"), ""));
|
||||
} else {
|
||||
for (auto& input_impl : input_impls_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -269,7 +269,8 @@ class CSVDatasetOp : public DatasetOpKernel {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("current_file_index"),
|
||||
current_file_index_));
|
||||
|
@ -289,9 +289,10 @@ class DenseToSparseBatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
DatasetIterator<Dataset<T>>::dataset()->batch_size_);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(Iterator::SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(Iterator::SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -212,10 +212,11 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
||||
return model::MakeInterleaveManyNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (selector_input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, selector_input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, selector_input_impl_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("selector_input_impl_empty"), ""));
|
||||
@ -223,7 +224,7 @@ class DirectedInterleaveDatasetOp : public DatasetOpKernel {
|
||||
for (size_t i = 0; i < data_input_impls_.size(); ++i) {
|
||||
const auto& data_input_impl = data_input_impls_[i];
|
||||
if (data_input_impl) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, data_input_impl));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, data_input_impl));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name(strings::StrCat("data_input_impl_empty[", i, "]")),
|
||||
|
@ -285,16 +285,18 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
||||
return model::MakeUnknownRatioNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_key_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->captured_init_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->captured_reduce_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->captured_finalize_func_->CheckExternalState());
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_key_func_->CheckExternalState()));
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_init_func_->CheckExternalState()));
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_reduce_func_->CheckExternalState()));
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_finalize_func_->CheckExternalState()));
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
|
||||
if (end_of_input_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -294,14 +294,16 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
return model::MakeUnknownRatioNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_key_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->captured_reduce_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->captured_window_size_func_->CheckExternalState());
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_key_func_->CheckExternalState()));
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_reduce_func_->CheckExternalState()));
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_window_size_func_->CheckExternalState()));
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
|
||||
if (end_of_input_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -342,7 +344,7 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
if (current_group_iterator_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, current_group_iterator_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_group_iterator_));
|
||||
|
||||
// Saving current_key_
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -115,10 +115,11 @@ class IgnoreErrorsDatasetOp : public UnaryDatasetOpKernel {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_)
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
else
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impls_empty"), ""));
|
||||
|
@ -134,7 +134,8 @@ class LMDBDatasetOp : public DatasetOpKernel {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
return errors::Unimplemented(
|
||||
"Checkpointing is currently not supported for LMDBDataset.");
|
||||
}
|
||||
|
@ -251,15 +251,17 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
/*max=*/ctx->runner_threadpool_size())});
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_func_->CheckExternalState()));
|
||||
mutex_lock l(*mu_);
|
||||
// Wait for all in-flight calls to complete.
|
||||
while (num_calls_ > 0) {
|
||||
cond_var_->wait(l);
|
||||
}
|
||||
DCHECK_EQ(num_calls_, 0);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kCallCounter), call_counter_));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kBatchResultsSize),
|
||||
|
@ -106,8 +106,9 @@ class NonSerializableDatasetOp : public UnaryDatasetOpKernel {
|
||||
return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -396,13 +396,15 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
/*parameters=*/{});
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_func_->CheckExternalState()));
|
||||
// The order of locking is important here to avoid deadlock.
|
||||
mutex_lock l(mu_);
|
||||
mutex_lock ckpt_l(ckpt_mu_);
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInputExhausted, ""));
|
||||
}
|
||||
@ -416,7 +418,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
TF_RETURN_IF_ERROR(WriteWorkerStateLocked(writer, i));
|
||||
}
|
||||
for (int i = 0; i < worker_thread_states_.size(); ++i) {
|
||||
TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(writer, i));
|
||||
TF_RETURN_IF_ERROR(WriteWorkerThreadStateLocked(ctx, writer, i));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kInterleaveSize,
|
||||
interleave_indices_.size()));
|
||||
@ -932,13 +934,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status WriteWorkerThreadStateLocked(IteratorStateWriter* writer, int index)
|
||||
Status WriteWorkerThreadStateLocked(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer, int index)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_, ckpt_mu_) {
|
||||
string iterator_name =
|
||||
strings::StrCat(prefix(), "::", kWorkerThread, "_", index);
|
||||
if (worker_thread_states_[index].iterator != nullptr) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
SaveInput(writer, worker_thread_states_[index].iterator));
|
||||
SaveInput(ctx, writer, worker_thread_states_[index].iterator));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(iterator_name, kIteratorExhausted, ""));
|
||||
|
@ -104,7 +104,8 @@ class RandomDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name("num_random_samples"),
|
||||
num_random_samples_));
|
||||
|
@ -179,13 +179,14 @@ class RebatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!input_impl_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("slice_number"), slice_number_));
|
||||
|
@ -145,7 +145,8 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
|
||||
generator_.Skip(num_random_samples_);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
// Save state needed to restore the random number generators.
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
@ -156,7 +157,7 @@ class SamplingDatasetOp::Dataset : public DatasetBase {
|
||||
writer->WriteScalar(this->full_name("seed2"), seeds_.second));
|
||||
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
|
@ -248,10 +248,12 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_func_->CheckExternalState()));
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
if (!state_.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("state_size"), state_.size()));
|
||||
|
@ -199,9 +199,10 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
return SaveInput(writer, input_impl_);
|
||||
return SaveInput(ctx, writer, input_impl_);
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
|
@ -142,8 +142,9 @@ class SleepDatasetOp : public UnaryDatasetOpKernel {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return SaveInput(writer, input_impl_);
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
return SaveInput(ctx, writer, input_impl_);
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
|
@ -239,13 +239,14 @@ class SlidingWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
dataset()->window_shift_);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!input_impl_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
}
|
||||
// Save buffer.
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(strings::StrCat("buffer_size"),
|
||||
|
@ -389,9 +389,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, iterator_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, iterator_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kState), static_cast<int64>(state_)));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kHashDir), hash_dir_));
|
||||
@ -623,7 +624,8 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kHashDir), hash_dir_));
|
||||
@ -1037,9 +1039,10 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
if (end_of_sequence_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kEndOfSequence), ""));
|
||||
@ -1482,8 +1485,9 @@ class SnapshotDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return SaveInput(writer, input_impl_);
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
return SaveInput(ctx, writer, input_impl_);
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
|
@ -158,7 +158,8 @@ class SqlDatasetOp : public DatasetOpKernel {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (query_connection_initialized_) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -128,9 +128,10 @@ class LatencyStatsDatasetOp : public UnaryDatasetOpKernel {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -246,9 +247,10 @@ class BytesProducedStatsDatasetOp : public UnaryDatasetOpKernel {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -166,11 +166,13 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_func_->CheckExternalState()));
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_)
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
else
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impls_empty"), ""));
|
||||
|
@ -216,10 +216,11 @@ class ThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
DCHECK(input_impl_ != nullptr);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -348,10 +349,11 @@ class MaxIntraOpParallelismDatasetOp : public UnaryDatasetOpKernel {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
DCHECK(input_impl_ != nullptr);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -470,10 +472,11 @@ class PrivateThreadPoolDatasetOp : public UnaryDatasetOpKernel {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
DCHECK(input_impl_ != nullptr);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -164,10 +164,11 @@ class UnbatchDatasetOp : public UnaryDatasetOpKernel {
|
||||
return model::MakeUnknownRatioNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
|
@ -103,10 +103,11 @@ class UniqueDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeUnknownRatioNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name("input_impl_empty"), ""));
|
||||
|
@ -194,11 +194,13 @@ class FilterDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeUnknownRatioNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_func_->CheckExternalState()));
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_)
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
else
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kInputImplsEmpty), ""));
|
||||
|
@ -190,7 +190,8 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
protected:
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentFileIndex),
|
||||
current_file_index_));
|
||||
@ -374,7 +375,8 @@ class FixedLengthRecordDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentFileIndex),
|
||||
current_file_index_));
|
||||
|
@ -162,11 +162,13 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeInterleaveManyNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_func_->CheckExternalState()));
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kElementIndex), element_index_));
|
||||
if (current_element_iterator_) {
|
||||
@ -178,7 +180,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
|
||||
full_name(strings::StrCat(kCapturedFuncInputs, "[", i, "]")),
|
||||
captured_func_inputs_[i]));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, current_element_iterator_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_element_iterator_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name(kCurrentElementIteratorUninitialized), ""));
|
||||
|
@ -158,7 +158,8 @@ class GeneratorDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
return errors::Unimplemented(
|
||||
"GeneratorDataset does not support checkpointing.");
|
||||
}
|
||||
|
@ -195,10 +195,12 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeInterleaveManyNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_func_->CheckExternalState()));
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kCycleIndex), cycle_index_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -207,7 +209,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kEndOfInput), ""));
|
||||
}
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNumOpen), num_open_));
|
||||
TF_RETURN_IF_ERROR(SaveCurrentElements(writer));
|
||||
TF_RETURN_IF_ERROR(SaveCurrentElements(ctx, writer));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -234,11 +236,12 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
private:
|
||||
Status SaveCurrentElements(IteratorStateWriter* writer)
|
||||
Status SaveCurrentElements(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
for (int idx = 0; idx < current_elements_.size(); idx++) {
|
||||
if (current_elements_[idx]) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, current_elements_[idx]));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, current_elements_[idx]));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name(strings::StrCat(kArgsSize, "[", idx, "]")),
|
||||
args_list_[idx].size()));
|
||||
|
@ -174,9 +174,11 @@ class MapDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_func_->CheckExternalState()));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -151,9 +151,10 @@ class ModelDatasetOp : public UnaryDatasetOpKernel {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -343,10 +343,11 @@ class PaddedBatchDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeKnownRatioNode(std::move(args), dataset()->batch_size_);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_)
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
else
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kExhausted), ""));
|
||||
return Status::OK();
|
||||
|
@ -377,8 +377,10 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
|
||||
// TODO(aaudibert): Refactor the implementations to avoid the need for
|
||||
// `IteratorContext` when saving the state of the iterator.
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
dataset()->captured_func_->CheckExternalState()));
|
||||
mutex_lock l(*mu_);
|
||||
wait_for_checkpoint_ = true;
|
||||
// Wait for all in-flight calls to complete.
|
||||
@ -400,7 +402,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
wait_for_checkpoint_ = false;
|
||||
DCHECK_EQ(num_active_workers_, 0);
|
||||
VLOG(4) << "State before save:\n" << DebugString();
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(prefix(), kBlockIndex, block_index_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
@ -410,8 +412,8 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kElementIdCounter,
|
||||
element_id_counter_));
|
||||
TF_RETURN_IF_ERROR(WriteCurrentElements(writer));
|
||||
TF_RETURN_IF_ERROR(WriteFutureElements(writer));
|
||||
TF_RETURN_IF_ERROR(WriteCurrentElements(ctx, writer));
|
||||
TF_RETURN_IF_ERROR(WriteFutureElements(ctx, writer));
|
||||
// Wake workers back up.
|
||||
current_workers_cond_var_.notify_all();
|
||||
future_workers_cond_var_.notify_all();
|
||||
@ -1124,13 +1126,14 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
return absl::StrCat(kResultsSuffix, "[", idx, "]", kErrorMessageSuffix);
|
||||
}
|
||||
|
||||
Status WriteElement(std::shared_ptr<Element> element, int idx,
|
||||
Status WriteElement(SerializationContext* ctx,
|
||||
std::shared_ptr<Element> element, int idx,
|
||||
const string& key_prefix, IteratorStateWriter* writer)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(*mu_) {
|
||||
const auto& iterator_name =
|
||||
absl::StrCat(prefix(), "::", key_prefix, "::", idx);
|
||||
if (element->iterator) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, element->iterator));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, element->iterator));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(iterator_name, kIdSuffix, element->id));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
@ -1165,26 +1168,28 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status WriteCurrentElements(IteratorStateWriter* writer)
|
||||
Status WriteCurrentElements(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kCurrentElementsSize,
|
||||
current_elements_.size()));
|
||||
for (int idx = 0; idx < current_elements_.size(); idx++) {
|
||||
if (current_elements_[idx]) {
|
||||
TF_RETURN_IF_ERROR(WriteElement(current_elements_[idx], idx,
|
||||
TF_RETURN_IF_ERROR(WriteElement(ctx, current_elements_[idx], idx,
|
||||
kCurrentElements, writer));
|
||||
}
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status WriteFutureElements(IteratorStateWriter* writer)
|
||||
Status WriteFutureElements(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer)
|
||||
TF_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(prefix(), kFutureElementsSize,
|
||||
future_elements_.size()));
|
||||
for (int idx = 0; idx < future_elements_.size(); idx++) {
|
||||
if (future_elements_[idx]) {
|
||||
TF_RETURN_IF_ERROR(WriteElement(future_elements_[idx], idx,
|
||||
TF_RETURN_IF_ERROR(WriteElement(ctx, future_elements_[idx], idx,
|
||||
kFutureElements, writer));
|
||||
}
|
||||
}
|
||||
|
@ -377,8 +377,10 @@ class ParallelMapIterator : public DatasetBaseIterator {
|
||||
/*max=*/ctx->runner_threadpool_size())});
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(parallel_map_functor_->CheckExternalState());
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(ctx->HandleCheckExternalStateStatus(
|
||||
parallel_map_functor_->CheckExternalState()));
|
||||
mutex_lock l(*mu_);
|
||||
// Wait for all in-flight calls to complete.
|
||||
while (num_calls_ > 0) {
|
||||
@ -388,7 +390,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
|
||||
return errors::FailedPrecondition(
|
||||
"Unexpected outstanding calls encountered.");
|
||||
}
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(
|
||||
full_name(strings::StrCat(kInvocationResults, kSizeSuffix)),
|
||||
invocation_results_.size()));
|
||||
|
@ -206,12 +206,13 @@ class PrefetchDatasetOp::Dataset : public DatasetBase {
|
||||
/*max=*/std::numeric_limits<int64>::max())});
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
// Acquire both locks to ensure that the prefetch thread and
|
||||
// all GetNext threads are blocked.
|
||||
mutex_lock input_l(input_mu_);
|
||||
mutex_lock l(*mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(prefix(), kBufferSize, buffer_.size()));
|
||||
for (size_t i = 0; i < buffer_.size(); i++) {
|
||||
|
@ -132,7 +132,8 @@ class RangeDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kNext), next_));
|
||||
return Status::OK();
|
||||
|
@ -124,7 +124,8 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
|
||||
/*ratio=*/kKnownRatio);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
return Status::OK();
|
||||
}
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
@ -172,13 +173,14 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
|
||||
/*ratio=*/kKnownRatio);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIteration), i_));
|
||||
if (!input_impl_) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
@ -249,10 +251,11 @@ class RepeatDatasetOp::Dataset : public DatasetBase {
|
||||
/*ratio=*/kKnownRatio);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!first_call_)
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
else
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kUninitialized), ""));
|
||||
return Status::OK();
|
||||
|
@ -169,12 +169,13 @@ class ShardDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeKnownRatioNode(std::move(args), dataset()->num_shards_);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!input_impl_) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kNextIndex), next_index_));
|
||||
}
|
||||
|
@ -278,7 +278,8 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
||||
generator_.Skip(num_random_samples_);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
// Save state needed to restore the random number generators.
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(this->full_name(kNumRandomSamples),
|
||||
@ -292,7 +293,7 @@ class ShuffleDatasetOpBase::ShuffleDatasetBase : public DatasetBase {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(this->full_name(kEndOfInputSequence), ""));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(this->SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(this->SaveInput(ctx, writer, input_impl_));
|
||||
}
|
||||
|
||||
// Save the epoch counter, buffer, and buffer slices.
|
||||
@ -526,14 +527,15 @@ class ShuffleDatasetOp::Dataset : public ShuffleDatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
// Save RNG state of Dataset.
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kDSNumRandomSamples),
|
||||
seed_generator_->num_random_samples()));
|
||||
|
||||
// Save the Iterator.
|
||||
return ShuffleDatasetBase::Iterator<Dataset>::SaveInternal(writer);
|
||||
return ShuffleDatasetBase::Iterator<Dataset>::SaveInternal(ctx, writer);
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
@ -634,14 +636,15 @@ class ShuffleDatasetOp::DatasetV2 : public ShuffleDatasetBase {
|
||||
return model::MakeKnownRatioNode(std::move(args), /*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
// Save state of the seed generator.
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kDSNumRandomSamples),
|
||||
seed_generator_->num_random_samples()));
|
||||
|
||||
// Save the tterator state.
|
||||
return ShuffleDatasetBase::Iterator<DatasetV2>::SaveInternal(writer);
|
||||
return ShuffleDatasetBase::Iterator<DatasetV2>::SaveInternal(ctx, writer);
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
|
@ -110,7 +110,8 @@ class SkipDatasetOp::Dataset : public DatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -172,11 +173,12 @@ class SkipDatasetOp::Dataset : public DatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
|
||||
}
|
||||
|
@ -161,7 +161,8 @@ class Dataset : public DatasetBase {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(Iterator::full_name("i"), i_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -95,7 +95,8 @@ class TakeDataset::EmptyIterator : public DatasetIterator<TakeDataset> {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
@ -142,11 +143,12 @@ class TakeDataset::FiniteIterator : public DatasetIterator<TakeDataset> {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
|
||||
}
|
||||
|
@ -117,7 +117,8 @@ class TensorDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (produced_)
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kProduced), ""));
|
||||
|
@ -136,7 +136,8 @@ class TensorSliceDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
|
||||
return Status::OK();
|
||||
|
@ -139,7 +139,8 @@ class TextLineDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentFileIndex),
|
||||
current_file_index_));
|
||||
|
@ -156,7 +156,8 @@ class TFRecordDatasetOp::Dataset : public DatasetBase {
|
||||
return model::MakeSourceNode(std::move(args));
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurrentFileIndex),
|
||||
current_file_index_));
|
||||
|
@ -97,7 +97,8 @@ class WindowDataset : public DatasetBase {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kCurIndex), i_));
|
||||
return Status::OK();
|
||||
|
@ -255,12 +255,13 @@ class WindowDatasetOp::Dataset : public DatasetBase {
|
||||
dataset()->window_shift_);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (!input_impl_) {
|
||||
TF_RETURN_IF_ERROR(writer->WriteScalar(full_name(kInputImplEmpty), ""));
|
||||
} else {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl_));
|
||||
}
|
||||
// Save buffer.
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -174,14 +174,15 @@ class ZipDatasetOp::Dataset : public DatasetBase {
|
||||
/*ratio=*/1);
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
Status SaveInternal(SerializationContext* ctx,
|
||||
IteratorStateWriter* writer) override {
|
||||
mutex_lock l(mu_);
|
||||
if (input_impls_.empty()) {
|
||||
TF_RETURN_IF_ERROR(
|
||||
writer->WriteScalar(full_name(kInputImplsEmpty), ""));
|
||||
} else {
|
||||
for (auto& input_impl : input_impls_)
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl));
|
||||
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impl));
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -258,8 +258,7 @@ class AutoShardDatasetTest(reader_dataset_ops_test_base.TFRecordDatasetTestBase,
|
||||
|
||||
graph_def = dataset._as_serialized_graph(
|
||||
strip_device_assignment=True,
|
||||
external_state_policy=
|
||||
dataset.options().experimental_external_state_policy)
|
||||
external_state_policy=distribute_options.ExternalStatePolicy.WARN)
|
||||
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_distribute.auto_shard_policy = sharding_policy
|
||||
|
@ -19,6 +19,7 @@ from __future__ import print_function
|
||||
|
||||
from tensorflow.python.compat import compat
|
||||
from tensorflow.python.data.experimental.ops.distribute_options import AutoShardPolicy
|
||||
from tensorflow.python.data.experimental.ops.distribute_options import ExternalStatePolicy
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.data.util import nest
|
||||
from tensorflow.python.framework import ops
|
||||
@ -162,10 +163,12 @@ def replicate(dataset, devices):
|
||||
|
||||
with ops.colocate_with(dataset._variant_tensor):
|
||||
dataset = dataset._apply_options()
|
||||
external_state_policy = dataset.options().experimental_external_state_policy
|
||||
policy = dataset.options().experimental_external_state_policy
|
||||
if policy is None:
|
||||
policy = ExternalStatePolicy.WARN
|
||||
graph_def = dataset._as_serialized_graph(
|
||||
strip_device_assignment=True,
|
||||
external_state_policy=external_state_policy)
|
||||
external_state_policy=policy)
|
||||
for device in devices:
|
||||
ds = _RemoteDataset(graph_def, device, dataset.element_spec)
|
||||
datasets[device] = ds
|
||||
|
@ -20,6 +20,7 @@ from __future__ import print_function
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
from tensorflow.python.data.experimental.ops import distribute_options
|
||||
from tensorflow.python.data.experimental.ops import grouping
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
from tensorflow.python.data.experimental.ops import scan_ops
|
||||
@ -35,6 +36,7 @@ from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import script_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
@ -531,6 +533,36 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
dataset = dataset.apply(take_while_ops.take_while(self._statefulBoolFunc))
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulExternalPolicy(self):
|
||||
checkpoint_directory = self.get_temp_dir()
|
||||
checkpoint_prefix = os.path.join(checkpoint_directory, "ckpt")
|
||||
dataset = dataset_ops.Dataset.range(4)
|
||||
|
||||
def fn(x):
|
||||
return x * x
|
||||
|
||||
dataset = dataset.map(
|
||||
lambda x: script_ops.eager_py_func(fn, [x], dtypes.int64))
|
||||
|
||||
options = dataset_ops.Options()
|
||||
options.experimental_external_state_policy = (
|
||||
distribute_options.ExternalStatePolicy.WARN)
|
||||
dataset = dataset.with_options(options)
|
||||
|
||||
iterator = iter(dataset)
|
||||
get_next = iterator.get_next
|
||||
checkpoint = trackable_utils.Checkpoint(iterator=iterator)
|
||||
self.assertEqual(0, get_next().numpy())
|
||||
self.assertEqual(1, get_next().numpy())
|
||||
save_path = checkpoint.save(checkpoint_prefix)
|
||||
self.assertEqual(4, get_next().numpy())
|
||||
self.assertEqual(9, get_next().numpy())
|
||||
checkpoint.restore(save_path).run_restore_ops()
|
||||
self.assertEqual(4, get_next().numpy())
|
||||
self.assertEqual(9, get_next().numpy())
|
||||
with self.assertRaises(errors.OutOfRangeError):
|
||||
get_next()
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
@ -225,9 +225,7 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
serialized graph.
|
||||
"""
|
||||
if external_state_policy:
|
||||
policy = None
|
||||
if external_state_policy:
|
||||
policy = external_state_policy.value
|
||||
policy = external_state_policy.value
|
||||
return gen_dataset_ops.dataset_to_graph_v2(
|
||||
self._variant_tensor,
|
||||
external_state_policy=policy,
|
||||
@ -1031,14 +1029,14 @@ class DatasetV2(tracking_base.Trackable, composite_tensor.CompositeTensor):
|
||||
|
||||
Example:
|
||||
If we had the following files on our filesystem:
|
||||
|
||||
|
||||
- /path/to/dir/a.txt
|
||||
- /path/to/dir/b.py
|
||||
- /path/to/dir/c.py
|
||||
|
||||
|
||||
If we pass "/path/to/dir/*.py" as the directory, the dataset
|
||||
would produce:
|
||||
|
||||
|
||||
- /path/to/dir/b.py
|
||||
- /path/to/dir/c.py
|
||||
|
||||
@ -2731,15 +2729,12 @@ class Options(options_lib.OptionsBase):
|
||||
experimental_external_state_policy = options_lib.create_option(
|
||||
name="experimental_external_state_policy",
|
||||
ty=distribute_options.ExternalStatePolicy,
|
||||
docstring="By default, tf.data will refuse to serialize a dataset or "
|
||||
"checkpoint its iterator if the dataset contains a stateful op as the "
|
||||
"serialization / checkpointing won't be able to capture its state. "
|
||||
"Users can -- at their own risk -- override this restriction by "
|
||||
"explicitly specifying that they are fine throwing away the state "
|
||||
"in these ops. There are three settings available - IGNORE: in which we"
|
||||
"completely ignore any state; WARN: We warn the user that some state "
|
||||
"might be thrown away; FAIL: We fail if any state is being captured.",
|
||||
default_factory=lambda: distribute_options.ExternalStatePolicy.WARN)
|
||||
docstring="This option can be used to override the default policy for "
|
||||
"how to handle external state when serializing a dataset or "
|
||||
"checkpointing its iterator. There are three settings available - "
|
||||
"IGNORE: in which we completely ignore any state; WARN: We warn the "
|
||||
"user that some state might be thrown away; FAIL: We fail if any state "
|
||||
"is being captured.")
|
||||
|
||||
def _graph_rewrites(self):
|
||||
"""Produces the list of enabled static graph rewrites."""
|
||||
|
@ -743,7 +743,17 @@ class OwnedIterator(trackable.Trackable, composite_tensor.CompositeTensor):
|
||||
def _gather_saveables_for_checkpoint(self):
|
||||
|
||||
def _saveable_factory(name):
|
||||
return _IteratorSaveable(self._iterator_resource, name)
|
||||
"""Returns a SaveableObject for serialization/deserialization."""
|
||||
policy = None
|
||||
if self._dataset:
|
||||
policy = self._dataset.options().experimental_external_state_policy
|
||||
if policy:
|
||||
return _IteratorSaveable(
|
||||
self._iterator_resource,
|
||||
name,
|
||||
external_state_policy=policy)
|
||||
else:
|
||||
return _IteratorSaveable(self._iterator_resource, name)
|
||||
|
||||
return {"ITERATOR": _saveable_factory}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user