Adding checkpointing to StatsAggregatorDataset. We just forward it to the input dataset at the moment.

PiperOrigin-RevId: 284297931
Change-Id: I95aa8937183524328ffb5e82e29b8c9c89fc2b01
This commit is contained in:
Rohan Jain 2019-12-06 18:04:59 -08:00 committed by TensorFlower Gardener
parent 164aa8303b
commit 6593263ed7
2 changed files with 5 additions and 8 deletions

View File

@ -200,14 +200,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
} }
Status SaveInternal(IteratorStateWriter* writer) override { Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented(dataset()->DebugString(), mutex_lock l(mu_);
" does not support checkpointing"); return SaveInput(writer, input_impl_);
} }
Status RestoreInternal(IteratorContext* ctx, Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override { IteratorStateReader* reader) override {
return errors::Unimplemented(dataset()->DebugString(), mutex_lock l(mu_);
" does not support checkpointing"); return RestoreInput(ctx, reader, input_impl_);
} }
private: private:

View File

@ -25,7 +25,6 @@ from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.kernel_tests import test_base from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test from tensorflow.python.platform import test
@ -104,9 +103,7 @@ class StatsDatasetSerializationTest(
@combinations.generate(test_base.default_test_combinations()) @combinations.generate(test_base.default_test_combinations())
def test_set_stats_aggregator_not_support_checkpointing(self): def test_set_stats_aggregator_not_support_checkpointing(self):
with self.assertRaisesRegexp(errors.UnimplementedError, self.run_core_tests(self._build_dataset_stats_aggregator, 10)
"does not support checkpointing"):
self.run_core_tests(self._build_dataset_stats_aggregator, 10)
if __name__ == "__main__": if __name__ == "__main__":