Enable checkpointing for ReshuffleDataset.
Previously, we would call CheckExternalState before checkpointing. But this was to strong a check, since some external state, such as TF resources, can be checkpointed. Instead of calling CheckExternalState, we now leave it up to Datasets' SaveInternal implementations to return an error if they contain non-checkpointable state. PiperOrigin-RevId: 299012205 Change-Id: I8638d15073758cd0070bdd4879202ad202a99850
This commit is contained in:
parent
a2dac1d40f
commit
3b6251b6cc
@ -815,9 +815,10 @@ class DatasetBase : public core::RefCounted {
|
||||
ABSL_DEPRECATED("Use CheckExternalState instead.")
|
||||
virtual bool IsStateful() const { return false; }
|
||||
|
||||
// Indicates whether the dataset depends on any external state. If so, the
|
||||
// method returns `errors::FailedPrecondition` with a message that identifies
|
||||
// the external state. Otherwise, the method returns `Status::OK()`.
|
||||
// Indicates whether the dataset depends on any external state which would
|
||||
// prevent it from being serializable. If so, the method returns
|
||||
// `errors::FailedPrecondition` with a message that identifies the external
|
||||
// state. Otherwise, the method returns `Status::OK()`.
|
||||
//
|
||||
// TODO(jsimsa): Make this method pure virtual once all `DatasetBase`
|
||||
// implementations have an override.
|
||||
@ -907,17 +908,6 @@ class DatasetBaseIterator : public IteratorBase {
|
||||
}
|
||||
|
||||
Status Save(SerializationContext* ctx, IteratorStateWriter* writer) final {
|
||||
Status s = params_.dataset->CheckExternalState();
|
||||
if (!s.ok()) {
|
||||
if (ctx->external_state_policy() ==
|
||||
SerializationContext::ExternalStatePolicy::kWarn) {
|
||||
LOG(WARNING) << "Dataset contains external state: " << s.ToString();
|
||||
}
|
||||
if (ctx->external_state_policy() ==
|
||||
SerializationContext::ExternalStatePolicy::kFail) {
|
||||
return s;
|
||||
}
|
||||
}
|
||||
return IteratorBase::Save(ctx, writer);
|
||||
}
|
||||
|
||||
|
@ -285,6 +285,13 @@ class GroupByReducerDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_key_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->captured_init_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->captured_reduce_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->captured_finalize_func_->CheckExternalState());
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
|
||||
|
@ -296,6 +296,11 @@ class GroupByWindowDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_key_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->captured_reduce_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(
|
||||
dataset()->captured_window_size_func_->CheckExternalState());
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
|
||||
|
@ -251,6 +251,7 @@ class MapAndBatchDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
mutex_lock l(*mu_);
|
||||
// Wait for all in-flight calls to complete.
|
||||
while (num_calls_ > 0) {
|
||||
|
@ -395,6 +395,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
// The order of locking is important here to avoid deadlock.
|
||||
mutex_lock l(mu_);
|
||||
mutex_lock ckpt_l(ckpt_mu_);
|
||||
|
@ -349,6 +349,8 @@ class ParseExampleDatasetOp : public UnaryDatasetOpKernel {
|
||||
explicit ParseExampleFunctor(const Dataset* dataset)
|
||||
: dataset_(dataset) {}
|
||||
|
||||
Status CheckExternalState() override { return Status::OK(); }
|
||||
|
||||
void MapFunc(IteratorContext* ctx, const string& prefix,
|
||||
std::vector<Tensor> input, std::vector<Tensor>* output,
|
||||
StatusCallback callback) override {
|
||||
|
@ -250,6 +250,7 @@ class ScanDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
if (!state_.empty()) {
|
||||
|
@ -169,6 +169,7 @@ class TakeWhileDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_)
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
|
@ -195,6 +195,7 @@ class FilterDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_)
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
|
@ -163,6 +163,7 @@ class FlatMapDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
mutex_lock l(mu_);
|
||||
if (input_impl_) {
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
|
@ -194,6 +194,7 @@ class InterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
mutex_lock l(mu_);
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
TF_RETURN_IF_ERROR(
|
||||
|
@ -175,6 +175,7 @@ class MapDatasetOp::Dataset : public DatasetBase {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
TF_RETURN_IF_ERROR(SaveInput(writer, input_impl_));
|
||||
return Status::OK();
|
||||
}
|
||||
|
@ -376,6 +376,7 @@ class ParallelInterleaveDatasetOp::Dataset : public DatasetBase {
|
||||
// TODO(aaudibert): Refactor the implementations to avoid the need for
|
||||
// `IteratorContext` when saving the state of the iterator.
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(dataset()->captured_func_->CheckExternalState());
|
||||
mutex_lock l(*mu_);
|
||||
wait_for_checkpoint_ = true;
|
||||
// Wait for all in-flight calls to complete.
|
||||
|
@ -190,6 +190,10 @@ class ParallelMapDatasetOp::Dataset : public DatasetBase {
|
||||
ctx, &instantiated_captured_func_);
|
||||
}
|
||||
|
||||
Status CheckExternalState() override {
|
||||
return dataset_->captured_func_->CheckExternalState();
|
||||
}
|
||||
|
||||
void MapFunc(IteratorContext* ctx, const string& prefix,
|
||||
std::vector<Tensor> input_element, std::vector<Tensor>* result,
|
||||
StatusCallback done) override {
|
||||
@ -375,6 +379,7 @@ class ParallelMapIterator : public DatasetBaseIterator {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
TF_RETURN_IF_ERROR(parallel_map_functor_->CheckExternalState());
|
||||
mutex_lock l(*mu_);
|
||||
// Wait for all in-flight calls to complete.
|
||||
while (num_calls_ > 0) {
|
||||
|
@ -64,6 +64,12 @@ class ParallelMapFunctor {
|
||||
// to specify error checking logic that can fail early.
|
||||
virtual Status InitFunc(IteratorContext* ctx) { return Status::OK(); }
|
||||
|
||||
// Indicates whether the functor depends on any external state.
|
||||
// If so, the method returns `errors::FailedPrecondition` with
|
||||
// a message that identifies the external state. Otherwise, the method returns
|
||||
// `Status::OK()`.
|
||||
virtual Status CheckExternalState() = 0;
|
||||
|
||||
// A function that transforms elements of one dataset into another
|
||||
// asynchronously. The arguments are:
|
||||
// 1. An `IteratorContext*` for the context in which the function should
|
||||
|
@ -66,47 +66,6 @@ class ShuffleDatasetSerializationTest(
|
||||
seed=seed,
|
||||
reshuffle_each_iteration=reshuffle_each_iteration), num_outputs)
|
||||
|
||||
# TODO(b/133780904): Re-enable this test once randomness state is hoisted out
|
||||
# of the input pipeline.
|
||||
@combinations.generate(
|
||||
combinations.times(
|
||||
test_base.default_test_combinations(),
|
||||
combinations.combine(
|
||||
reshuffle_each_iteration=[True, False],
|
||||
buffer_size=[1, 3, 5, 8, 10])))
|
||||
def _testNonDeterministicSeeding(self, reshuffle_each_iteration, buffer_size):
|
||||
range_limit = 5
|
||||
num_repeats = 2
|
||||
num_outputs = range_limit * num_repeats
|
||||
|
||||
def ds_fn():
|
||||
# pylint: disable=cell-var-from-loop
|
||||
return self._build_shuffle_dataset(
|
||||
range_limit=range_limit,
|
||||
num_repeats=num_repeats,
|
||||
buffer_size=buffer_size,
|
||||
seed=None, # Iterator seeds are generated non-deterministically.
|
||||
reshuffle_each_iteration=reshuffle_each_iteration)
|
||||
# pylint: enable=cell-var-from-loop
|
||||
|
||||
# We checkpoint the initial state of the Dataset so that we can restore
|
||||
# the seeds in the next run. Since the seeding is non-deterministic
|
||||
# the dataset gets initialized with different seeds each time.
|
||||
expected = self.gen_outputs(
|
||||
ds_fn,
|
||||
break_points=[0],
|
||||
num_outputs=num_outputs,
|
||||
ckpt_saved=False,
|
||||
verify_exhausted=False,
|
||||
save_checkpoint_at_end=False)
|
||||
actual = self.gen_outputs(
|
||||
ds_fn,
|
||||
break_points=self.gen_break_points(num_outputs),
|
||||
num_outputs=num_outputs,
|
||||
ckpt_saved=True,
|
||||
verify_exhausted=False)
|
||||
self.match(expected, actual)
|
||||
|
||||
@combinations.generate(
|
||||
combinations.combine(
|
||||
tf_api_version=1,
|
||||
|
@ -20,7 +20,10 @@ from __future__ import print_function
|
||||
import os
|
||||
|
||||
from absl.testing import parameterized
|
||||
|
||||
from tensorflow.python.data.experimental.ops import grouping
|
||||
from tensorflow.python.data.experimental.ops import interleave_ops
|
||||
from tensorflow.python.data.experimental.ops import scan_ops
|
||||
from tensorflow.python.data.experimental.ops import take_while_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
@ -31,6 +34,7 @@ from tensorflow.python.ops import gen_dataset_ops
|
||||
from tensorflow.python.ops import io_ops
|
||||
from tensorflow.python.ops import math_ops
|
||||
from tensorflow.python.ops import parsing_ops
|
||||
from tensorflow.python.ops import random_ops
|
||||
from tensorflow.python.ops import variables
|
||||
from tensorflow.python.platform import gfile
|
||||
from tensorflow.python.platform import test
|
||||
@ -373,6 +377,160 @@ class CheckpointTest(test_base.DatasetTestBase, parameterized.TestCase):
|
||||
self.assertEqual(i * 2 + j, self.evaluate(get_next()))
|
||||
checkpoint.save(file_prefix=checkpoint_prefix)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testSaveRestoreReshuffleDataset(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.shuffle(10, reshuffle_each_iteration=True)
|
||||
iterator = iter(dataset)
|
||||
ckpt = trackable_utils.Checkpoint(
|
||||
step=variables.Variable(0), iterator=iterator)
|
||||
manager = checkpoint_management.CheckpointManager(
|
||||
ckpt, self.get_temp_dir(), max_to_keep=3)
|
||||
|
||||
iter1 = [next(iterator).numpy() for _ in range(5)]
|
||||
|
||||
manager.save()
|
||||
iter2 = [next(iterator).numpy() for _ in range(5)]
|
||||
|
||||
ckpt.restore(manager.latest_checkpoint)
|
||||
iter3 = [next(iterator).numpy() for _ in range(5)]
|
||||
|
||||
self.assertNotEqual(iter1, iter2)
|
||||
self.assertCountEqual(iter2, iter3)
|
||||
|
||||
def _assertNotCheckpointable(self, dataset):
|
||||
iterator = iter(dataset)
|
||||
ckpt = trackable_utils.Checkpoint(
|
||||
step=variables.Variable(0), iterator=iterator)
|
||||
manager = checkpoint_management.CheckpointManager(
|
||||
ckpt, self.get_temp_dir(), max_to_keep=3)
|
||||
with self.assertRaises(errors.FailedPreconditionError):
|
||||
manager.save()
|
||||
|
||||
@staticmethod
|
||||
def _statefulInt64Func(_):
|
||||
return random_ops.random_uniform((), 0, 1, dtypes.int64)
|
||||
|
||||
@staticmethod
|
||||
def _statefulBoolFunc(_):
|
||||
return random_ops.random_uniform((), 0, 1, dtypes.int64) < 1
|
||||
|
||||
@staticmethod
|
||||
def _statefulDatasetFunc(_):
|
||||
x = random_ops.random_uniform((), 0, 1, dtypes.int64)
|
||||
return dataset_ops.Dataset.range(x)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulFilterNotCheckpointable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.filter(self._statefulBoolFunc)
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulFlatMapNotCheckpointable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.flat_map(self._statefulDatasetFunc)
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulInterleaveNotCheckpointable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.interleave(self._statefulDatasetFunc)
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulMapNotCheckpointable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.map(self._statefulBoolFunc)
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulParallelInterleaveNotCheckpointable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.interleave(
|
||||
self._statefulDatasetFunc, num_parallel_calls=2)
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulParallelMapNotCheckpointable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.map(self._statefulBoolFunc, num_parallel_calls=2)
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulGroupByReducerNotCheckpointable(self):
|
||||
stateful_key_func = self._statefulInt64Func
|
||||
key_func = lambda _: math_ops.cast(0, dtypes.int64)
|
||||
stateful_init_func = self._statefulBoolFunc
|
||||
init_func = lambda x: True
|
||||
stateful_reduce_func = lambda _, x: self._statefulBoolFunc(x)
|
||||
reduce_func = lambda _, x: True
|
||||
stateful_finalize_func = self._statefulBoolFunc
|
||||
finalize_func = lambda x: True
|
||||
|
||||
test_cases = [
|
||||
(stateful_key_func, init_func, reduce_func, finalize_func),
|
||||
(key_func, stateful_init_func, reduce_func, finalize_func),
|
||||
(key_func, init_func, stateful_reduce_func, finalize_func),
|
||||
(key_func, init_func, reduce_func, stateful_finalize_func),
|
||||
]
|
||||
for key_func, init_func, reduce_func, finalize_func in test_cases:
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
reducer = grouping.Reducer(init_func, reduce_func, finalize_func)
|
||||
dataset = dataset.apply(grouping.group_by_reducer(key_func, reducer))
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulGroupByWindowNotCheckpointable(self):
|
||||
stateful_key_func = self._statefulInt64Func
|
||||
key_func = lambda _: math_ops.cast(0, dtypes.int64)
|
||||
stateful_reduce_func = lambda _, x: self._statefulDatasetFunc(x)
|
||||
reduce_func = lambda _, x: x
|
||||
stateful_window_func = self._statefulInt64Func
|
||||
window_func = lambda x: math_ops.cast(0, dtypes.int64)
|
||||
|
||||
test_cases = [
|
||||
(stateful_key_func, reduce_func, window_func),
|
||||
(key_func, stateful_reduce_func, window_func),
|
||||
(key_func, reduce_func, stateful_window_func),
|
||||
]
|
||||
for key_func_fn, reduce_func_fn, window_func in test_cases:
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.apply(
|
||||
grouping.group_by_window(
|
||||
key_func_fn, reduce_func_fn, window_size_func=window_func))
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulMapAndBatchNotCheckpointable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.map(self._statefulBoolFunc)
|
||||
dataset = dataset.batch(2)
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulExperimentalParallelInterleaveNotCheckpointable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.apply(
|
||||
interleave_ops.parallel_interleave(self._statefulDatasetFunc, 2))
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulScanNotCheckpointable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
|
||||
def stateful_scan(state, element):
|
||||
return state, self._statefulBoolFunc(element)
|
||||
|
||||
dataset = dataset.apply(scan_ops.scan(0, stateful_scan))
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
@combinations.generate(test_base.eager_only_combinations())
|
||||
def testStatefulTakeWhileNotCheckpointable(self):
|
||||
dataset = dataset_ops.Dataset.range(10)
|
||||
dataset = dataset.apply(take_while_ops.take_while(self._statefulBoolFunc))
|
||||
self._assertNotCheckpointable(dataset)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
test.main()
|
||||
|
Loading…
Reference in New Issue
Block a user