diff --git a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc index e7d64e10c50..de124e49fe9 100644 --- a/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc +++ b/tensorflow/core/kernels/data/experimental/set_stats_aggregator_dataset_op.cc @@ -200,14 +200,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel { } Status SaveInternal(IteratorStateWriter* writer) override { - return errors::Unimplemented(dataset()->DebugString(), - " does not support checkpointing"); + mutex_lock l(mu_); + return SaveInput(writer, input_impl_); } Status RestoreInternal(IteratorContext* ctx, IteratorStateReader* reader) override { - return errors::Unimplemented(dataset()->DebugString(), - " does not support checkpointing"); + mutex_lock l(mu_); + return RestoreInput(ctx, reader, input_impl_); } private: diff --git a/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py index 27b14f0730f..66658ea0a5b 100644 --- a/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py +++ b/tensorflow/python/data/experimental/kernel_tests/serialization/stats_dataset_serialization_test.py @@ -25,7 +25,6 @@ from tensorflow.python.data.experimental.ops import stats_ops from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.ops import dataset_ops from tensorflow.python.framework import combinations -from tensorflow.python.framework import errors from tensorflow.python.framework import ops from tensorflow.python.ops import array_ops from tensorflow.python.platform import test @@ -104,9 +103,7 @@ class StatsDatasetSerializationTest( @combinations.generate(test_base.default_test_combinations()) def test_set_stats_aggregator_not_support_checkpointing(self): - with self.assertRaisesRegexp(errors.UnimplementedError, - "does not support checkpointing"): - self.run_core_tests(self._build_dataset_stats_aggregator, 10) + self.run_core_tests(self._build_dataset_stats_aggregator, 10) if __name__ == "__main__":