[tf.data] Check cycle length when restoring parallel interleave iterator.

If we try to restore into an iterator with a smaller cycle length than the original, it will produce a segmentation fault. This can happen either due to user error or due to the cycle_length being autotuned.

This CL is a stopgap solution to give a better error message than a segmentation fault. In the long term we aim to support adjusting the cycle_length so that autotuned cycle_length + checkpointing just works.
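
For illustration, a minimal Python sketch of failure mode (1). The dataset shape, checkpoint path, and cycle lengths are hypothetical and not taken from this CL, and the exact point where the error surfaces (at restore() or at the first next()) may vary:

    import tensorflow as tf

    def make_dataset(cycle_length):
      # num_parallel_calls makes this a parallel interleave, the op this CL touches.
      return tf.data.Dataset.range(100).interleave(
          lambda x: tf.data.Dataset.from_tensors(x).repeat(10),
          cycle_length=cycle_length,
          num_parallel_calls=tf.data.AUTOTUNE)

    iterator = iter(make_dataset(cycle_length=4))
    next(iterator)  # advance so the iterator has per-element state to save
    path = tf.train.Checkpoint(iterator=iterator).save("/tmp/interleave_ckpt")

    # Restoring into a smaller cycle length (2 < 4) used to segfault; with
    # this CL the mismatch is reported as a FailedPrecondition instead.
    restored = iter(make_dataset(cycle_length=2))
    try:
      tf.train.Checkpoint(iterator=restored).restore(path)
      next(restored)
    except tf.errors.FailedPreconditionError as err:
      print(err)  # "The iterator cycle length 2 is different from the
                  #  cycle length to restore from the checkpoint: 4"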

PiperOrigin-RevId: 342733442
Change-Id: Ie9869224cc1598e74e6eb00397df35e6a1a46859
commit abe233392e
parent 3a3063e39a
Author: Andrew Audibert
Date: 2020-11-16 15:23:37 -08:00
Committed by: TensorFlower Gardener

@@ -1317,7 +1317,20 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
       mutex_lock l(*mu_);
       TF_RETURN_IF_ERROR(
           reader->ReadScalar(prefix(), kCurrentElementsSize, &size));
-      DCHECK_EQ(current_elements_.size(), size);
+      if (current_elements_.size() != size) {
+        // This could mean two things: (1) the user created their checkpoint
+        // from a dataset with one cycle_length, then changed the cycle_length
+        // and tried to restore from the old checkpoint, or (2) the user set
+        // the cycle length to tf.data.AUTOTUNE, wrote the checkpoint from one
+        // machine, then tried to restore the checkpoint on another machine
+        // with a different CPU budget (causing autotune to pick a different
+        // cycle length).
+        return errors::FailedPrecondition(
+            "The iterator cycle length ", current_elements_.size(),
+            " is different from the cycle length to restore from the "
+            "checkpoint: ",
+            size);
+      }
     }
     if (size == 0) {
       return Status::OK();
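
Note on the design choice: the replaced DCHECK_EQ only fires in debug builds, so in optimized builds the size mismatch went undetected and the subsequent element-by-element restore presumably indexed past the end of current_elements_, hence the segmentation fault. The explicit check returns a Status in all build modes, which surfaces in Python as the catchable tf.errors.FailedPreconditionError shown in the sketch above.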