Enable checkpointing for ReshuffleDataset.

Previously, we would call CheckExternalState before checkpointing. But this was too strong a check, since some external state, such as TF resources, can be checkpointed. Instead of calling CheckExternalState up front, we now leave it to each dataset's SaveInternal implementation to return an error if it contains non-checkpointable state.

PiperOrigin-RevId: 299012205
Change-Id: I8638d15073758cd0070bdd4879202ad202a99850
Andrew Audibert, 2020-03-04 21:09:05 -08:00 (committed by TensorFlower Gardener)
parent a2dac1d40f
commit 3b6251b6cc
17 changed files with 197 additions and 56 deletions
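To make the change concrete before the diffs: checkpointing an iterator over a reshuffling dataset used to fail the up-front CheckExternalState check; it now succeeds, because the shuffle seeds are state that SaveInternal knows how to serialize. A minimal eager-mode sketch of the newly enabled behavior (illustrative only, mirroring the testSaveRestoreReshuffleDataset test added below; the checkpoint path is arbitrary and a TF release containing this change is assumed):

import tensorflow as tf

# A dataset that reshuffles on every iteration. Before this change, trying to
# checkpoint its iterator raised FailedPrecondition.
ds = tf.data.Dataset.range(10).shuffle(10, reshuffle_each_iteration=True)
it = iter(ds)

ckpt = tf.train.Checkpoint(iterator=it)
path = ckpt.save("/tmp/reshuffle_ckpt")  # now succeeds

after_save = [next(it).numpy() for _ in range(5)]
ckpt.restore(path)  # rewinds the iterator to the save point
after_restore = [next(it).numpy() for _ in range(5)]

# The restored iterator replays the same remaining elements (as a multiset),
# matching the assertCountEqual check in the new test.
assert sorted(after_save) == sorted(after_restore)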


@@ -815,9 +815,10 @@ class DatasetBase : public core::RefCounted {
   ABSL_DEPRECATED("Use CheckExternalState instead.")
   virtual bool IsStateful() const { return false; }
 
-  // Indicates whether the dataset depends on any external state. If so, the
-  // method returns `errors::FailedPrecondition` with a message that identifies
-  // the external state. Otherwise, the method returns `Status::OK()`.
+  // Indicates whether the dataset depends on any external state which would
+  // prevent it from being serializable. If so, the method returns
+  // `errors::FailedPrecondition` with a message that identifies the external
+  // state. Otherwise, the method returns `Status::OK()`.
   //
   // TODO(jsimsa): Make this method pure virtual once all `DatasetBase`
   // implementations have an override.
@@ -907,17 +908,6 @@ class DatasetBaseIterator : public IteratorBase {
   }
 
   Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
-    Status s = params_.dataset->CheckExternalState();
-    if (!s.ok()) {
-      if (ctx->external_state_policy() ==
-          SerializationContext::ExternalStatePolicy::kWarn) {
-        LOG(WARNING) << "Dataset contains external state: " << s.ToString();
-      }
-      if (ctx->external_state_policy() ==
-          SerializationContext::ExternalStatePolicy::kFail) {
-        return s;
-      }
-    }
     return IteratorBase::Save(ctx, writer);
   }


@@ -285,6 +285,13 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
     }
 
     Status SaveInternal(IteratorStateWriter* writer) override {
+      TF_RETURN_IF_ERROR(dataset()->captured_key_func_->CheckExternalState());
+      TF_RETURN_IF_ERROR(
+          dataset()->captured_init_func_->CheckExternalState());
+      TF_RETURN_IF_ERROR(
+          dataset()->captured_reduce_func_->CheckExternalState());
+      TF_RETURN_IF_ERROR(
+          dataset()->captured_finalize_func_->CheckExternalState());
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));


@@ -296,6 +296,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
     }
 
     Status SaveInternal(IteratorStateWriter* writer) override {
+      TF_RETURN_IF_ERROR(dataset()->captured_key_func_->CheckExternalState());
+      TF_RETURN_IF_ERROR(
+          dataset()->captured_reduce_func_->CheckExternalState());
+      TF_RETURN_IF_ERROR(
+          dataset()->captured_window_size_func_->CheckExternalState());
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));


@@ -251,6 +251,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
     }
 
     Status SaveInternal(IteratorStateWriter* writer) override {
+      TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
      mutex_lock l(*mu_);
       // Wait for all in-flight calls to complete.
       while (num_calls_ > 0) {


@@ -395,6 +395,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
     }
 
     Status SaveInternal(IteratorStateWriter* writer) override {
+      TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
       // The order of locking is important here to avoid deadlock.
       mutex_lock l(mu_);
       mutex_lock ckpt_l(ckpt_mu_);


@@ -349,6 +349,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
     explicit ParseExampleFunctor(const Dataset* dataset)
         : dataset_(dataset) {}
 
+    Status CheckExternalState() override { return Status::OK(); }
+
     void MapFunc(IteratorContext* ctx, const string& prefix,
                  std::vector<Tensor> input, std::vector<Tensor>* output,
                  StatusCallback callback) override {


@@ -250,6 +250,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
     }
 
     Status SaveInternal(IteratorStateWriter* writer) override {
+      TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
       if (!state_.empty()) {


@@ -169,6 +169,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
     }
 
     Status SaveInternal(IteratorStateWriter* writer) override {
+      TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
       mutex_lock l(mu_);
       if (input_impl_)
         TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));


@@ -195,6 +195,7 @@ class FilterDatasetOp::Dataset : public DatasetBase {
     }
 
     Status SaveInternal(IteratorStateWriter* writer) override {
+      TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
       mutex_lock l(mu_);
       if (input_impl_)
         TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));


@@ -163,6 +163,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
     }
 
     Status SaveInternal(IteratorStateWriter* writer) override {
+      TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
       mutex_lock l(mu_);
       if (input_impl_) {
         TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));


@@ -194,6 +194,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
     }
 
     Status SaveInternal(IteratorStateWriter* writer) override {
+      TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
       mutex_lock l(mu_);
       TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
       TF_RETURN_IF_ERROR(


@@ -175,6 +175,7 @@ class MapDatasetOp::Dataset : public DatasetBase {
     }
 
     Status SaveInternal(IteratorStateWriter* writer) override {
+      TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
       TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
       return Status::OK();
     }


@@ -376,6 +376,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
     // TODO(aaudibert): Refactor the implementations to avoid the need for
     // `IteratorContext` when saving the state of the iterator.
     Status SaveInternal(IteratorStateWriter* writer) override {
+      TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
       mutex_lock l(*mu_);
       wait_for_checkpoint_ = true;
       // Wait for all in-flight calls to complete.


@@ -190,6 +190,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
                                              ctx, &instantiated_captured_func_);
     }
 
+    Status CheckExternalState() override {
+      return dataset_->captured_func_->CheckExternalState();
+    }
+
     void MapFunc(IteratorContext* ctx, const string& prefix,
                  std::vector<Tensor> input_element, std::vector<Tensor>* result,
                  StatusCallback done) override {
@@ -375,6 +379,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
   }
 
   Status SaveInternal(IteratorStateWriter* writer) override {
+    TF_RETURN_IF_ERROR(parallel_map_functor_->CheckExternalState());
     mutex_lock l(*mu_);
     // Wait for all in-flight calls to complete.
     while (num_calls_ > 0) {


@@ -64,6 +64,12 @@ class ParallelMapFunctor {
   // to specify error checking logic that can fail early.
   virtual Status InitFunc(IteratorContext* ctx) { return Status::OK(); }
 
+  // Indicates whether the functor depends on any external state.
+  // If so, the method returns `errors::FailedPrecondition` with
+  // a message that identifies the external state. Otherwise, the method returns
+  // `Status::OK()`.
+  virtual Status CheckExternalState() = 0;
+
   // A function that transforms elements of one dataset into another
   // asynchronously. The arguments are:
   // 1. An `IteratorContext*` for the context in which the function should


@@ -66,47 +66,6 @@ class ShuffleDatasetSerializationTest(
             seed=seed,
             reshuffle_each_iteration=reshuffle_each_iteration), num_outputs)
 
-  # TODO(b/133780904): Re-enable this test once randomness state is hoisted out
-  # of the input pipeline.
-  @combinations.generate(
-      combinations.times(
-          test_base.default_test_combinations(),
-          combinations.combine(
-              reshuffle_each_iteration=[True, False],
-              buffer_size=[1, 3, 5, 8, 10])))
-  def _testNonDeterministicSeeding(self, reshuffle_each_iteration, buffer_size):
-    range_limit = 5
-    num_repeats = 2
-    num_outputs = range_limit * num_repeats
-
-    def ds_fn():
-      # pylint: disable=cell-var-from-loop
-      return self._build_shuffle_dataset(
-          range_limit=range_limit,
-          num_repeats=num_repeats,
-          buffer_size=buffer_size,
-          seed=None,  # Iterator seeds are generated non-deterministically.
-          reshuffle_each_iteration=reshuffle_each_iteration)
-      # pylint: enable=cell-var-from-loop
-
-    # We checkpoint the initial state of the Dataset so that we can restore
-    # the seeds in the next run. Since the seeding is non-deterministic
-    # the dataset gets initialized with different seeds each time.
-    expected = self.gen_outputs(
-        ds_fn,
-        break_points=[0],
-        num_outputs=num_outputs,
-        ckpt_saved=False,
-        verify_exhausted=False,
-        save_checkpoint_at_end=False)
-    actual = self.gen_outputs(
-        ds_fn,
-        break_points=self.gen_break_points(num_outputs),
-        num_outputs=num_outputs,
-        ckpt_saved=True,
-        verify_exhausted=False)
-    self.match(expected, actual)
-
   @combinations.generate(
       combinations.combine(
           tf_api_version=1,


@@ -20,7 +20,10 @@ from __future__ import print_function
 import os
 
 from absl.testing import parameterized
 
+from tensorflow.python.data.experimental.ops import grouping
+from tensorflow.python.data.experimental.ops import interleave_ops
+from tensorflow.python.data.experimental.ops import scan_ops
+from tensorflow.python.data.experimental.ops import take_while_ops
 from tensorflow.python.data.kernel_tests import test_base
 from tensorflow.python.data.ops import dataset_ops
 from tensorflow.python.framework import combinations
@@ -31,6 +34,7 @@ from tensorflow.python.ops import gen_dataset_ops
 from tensorflow.python.ops import io_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import variables
 from tensorflow.python.platform import gfile
 from tensorflow.python.platform import test
@@ -373,6 +377,160 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
           self.assertEqual(i * 2 + j, self.evaluate(get_next()))
         checkpoint.save(file_prefix=checkpoint_prefix)
 
+  @combinations.generate(test_base.eager_only_combinations())
+  def testSaveRestoreReshuffleDataset(self):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.shuffle(10, reshuffle_each_iteration=True)
+    iterator = iter(dataset)
+
+    ckpt = trackable_utils.Checkpoint(
+        step=variables.Variable(0), iterator=iterator)
+    manager = checkpoint_management.CheckpointManager(
+        ckpt, self.get_temp_dir(), max_to_keep=3)
+
+    iter1 = [next(iterator).numpy() for _ in range(5)]
+
+    manager.save()
+    iter2 = [next(iterator).numpy() for _ in range(5)]
+
+    ckpt.restore(manager.latest_checkpoint)
+    iter3 = [next(iterator).numpy() for _ in range(5)]
+
+    self.assertNotEqual(iter1, iter2)
+    self.assertCountEqual(iter2, iter3)
+
+  def _assertNotCheckpointable(self, dataset):
+    iterator = iter(dataset)
+    ckpt = trackable_utils.Checkpoint(
+        step=variables.Variable(0), iterator=iterator)
+    manager = checkpoint_management.CheckpointManager(
+        ckpt, self.get_temp_dir(), max_to_keep=3)
+    with self.assertRaises(errors.FailedPreconditionError):
+      manager.save()
+
+  @staticmethod
+  def _statefulInt64Func(_):
+    return random_ops.random_uniform((), 0, 1, dtypes.int64)
+
+  @staticmethod
+  def _statefulBoolFunc(_):
+    return random_ops.random_uniform((), 0, 1, dtypes.int64) < 1
+
+  @staticmethod
+  def _statefulDatasetFunc(_):
+    x = random_ops.random_uniform((), 0, 1, dtypes.int64)
+    return dataset_ops.Dataset.range(x)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulFilterNotCheckpointable(self):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.filter(self._statefulBoolFunc)
+    self._assertNotCheckpointable(dataset)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulFlatMapNotCheckpointable(self):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.flat_map(self._statefulDatasetFunc)
+    self._assertNotCheckpointable(dataset)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulInterleaveNotCheckpointable(self):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.interleave(self._statefulDatasetFunc)
+    self._assertNotCheckpointable(dataset)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulMapNotCheckpointable(self):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.map(self._statefulBoolFunc)
+    self._assertNotCheckpointable(dataset)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulParallelInterleaveNotCheckpointable(self):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.interleave(
+        self._statefulDatasetFunc, num_parallel_calls=2)
+    self._assertNotCheckpointable(dataset)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulParallelMapNotCheckpointable(self):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.map(self._statefulBoolFunc, num_parallel_calls=2)
+    self._assertNotCheckpointable(dataset)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulGroupByReducerNotCheckpointable(self):
+    stateful_key_func = self._statefulInt64Func
+    key_func = lambda _: math_ops.cast(0, dtypes.int64)
+
+    stateful_init_func = self._statefulBoolFunc
+    init_func = lambda x: True
+
+    stateful_reduce_func = lambda _, x: self._statefulBoolFunc(x)
+    reduce_func = lambda _, x: True
+
+    stateful_finalize_func = self._statefulBoolFunc
+    finalize_func = lambda x: True
+
+    test_cases = [
+        (stateful_key_func, init_func, reduce_func, finalize_func),
+        (key_func, stateful_init_func, reduce_func, finalize_func),
+        (key_func, init_func, stateful_reduce_func, finalize_func),
+        (key_func, init_func, reduce_func, stateful_finalize_func),
+    ]
+    for key_func, init_func, reduce_func, finalize_func in test_cases:
+      dataset = dataset_ops.Dataset.range(10)
+      reducer = grouping.Reducer(init_func, reduce_func, finalize_func)
+      dataset = dataset.apply(grouping.group_by_reducer(key_func, reducer))
+      self._assertNotCheckpointable(dataset)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulGroupByWindowNotCheckpointable(self):
+    stateful_key_func = self._statefulInt64Func
+    key_func = lambda _: math_ops.cast(0, dtypes.int64)
+
+    stateful_reduce_func = lambda _, x: self._statefulDatasetFunc(x)
+    reduce_func = lambda _, x: x
+
+    stateful_window_func = self._statefulInt64Func
+    window_func = lambda x: math_ops.cast(0, dtypes.int64)
+
+    test_cases = [
+        (stateful_key_func, reduce_func, window_func),
+        (key_func, stateful_reduce_func, window_func),
+        (key_func, reduce_func, stateful_window_func),
+    ]
+    for key_func_fn, reduce_func_fn, window_func in test_cases:
+      dataset = dataset_ops.Dataset.range(10)
+      dataset = dataset.apply(
+          grouping.group_by_window(
+              key_func_fn, reduce_func_fn, window_size_func=window_func))
+      self._assertNotCheckpointable(dataset)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulMapAndBatchNotCheckpointable(self):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.map(self._statefulBoolFunc)
+    dataset = dataset.batch(2)
+    self._assertNotCheckpointable(dataset)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulExperimentalParallelInterleaveNotCheckpointable(self):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(
+        interleave_ops.parallel_interleave(self._statefulDatasetFunc, 2))
+    self._assertNotCheckpointable(dataset)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulScanNotCheckpointable(self):
+    dataset = dataset_ops.Dataset.range(10)
+
+    def stateful_scan(state, element):
+      return state, self._statefulBoolFunc(element)
+
+    dataset = dataset.apply(scan_ops.scan(0, stateful_scan))
+    self._assertNotCheckpointable(dataset)
+
+  @combinations.generate(test_base.eager_only_combinations())
+  def testStatefulTakeWhileNotCheckpointable(self):
+    dataset = dataset_ops.Dataset.range(10)
+    dataset = dataset.apply(take_while_ops.take_while(self._statefulBoolFunc))
+    self._assertNotCheckpointable(dataset)
+
 
 if __name__ == "__main__":
   test.main()
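Conversely, the batch of new tests above pins down what remains non-checkpointable: any pipeline whose captured function depends on external state that SaveInternal cannot serialize. An illustrative eager-mode sketch of that failure mode (the map function below plays the role of the tests' _statefulBoolFunc; the checkpoint path is arbitrary):

import tensorflow as tf

# The map function draws from an op-level RNG, i.e. external state that an
# iterator checkpoint cannot capture.
ds = tf.data.Dataset.range(10).map(
    lambda _: tf.random.uniform((), 0, 1, tf.int64))
it = iter(ds)

ckpt = tf.train.Checkpoint(iterator=it)
try:
  ckpt.save("/tmp/stateful_ckpt")
except tf.errors.FailedPreconditionError as e:
  # SaveInternal propagates the CheckExternalState error, as exercised by
  # testStatefulMapNotCheckpointable above.
  print("Not checkpointable:", e)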