Adding checkpointing to StatsAggregatorDataset. We just forward it to the input dataset at the moment.

PiperOrigin-RevId: 284297931
Change-Id: I95aa8937183524328ffb5e82e29b8c9c89fc2b01
This commit is contained in:
Rohan Jain 2019-12-06 18:04:59 -08:00 committed by TensorFlower Gardener
parent 164aa8303b
commit 6593263ed7
2 changed files with 5 additions and 8 deletions

View File

@ -200,14 +200,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
}
Status SaveInternal(IteratorStateWriter* writer) override {
return errors::Unimplemented(dataset()->DebugString(),
" does not support checkpointing");
mutex_lock l(mu_);
return SaveInput(writer, input_impl_);
}
Status RestoreInternal(IteratorContext* ctx,
IteratorStateReader* reader) override {
return errors::Unimplemented(dataset()->DebugString(),
" does not support checkpointing");
mutex_lock l(mu_);
return RestoreInput(ctx, reader, input_impl_);
}
private:

View File

@ -25,7 +25,6 @@ from tensorflow.python.data.experimental.ops import stats_ops
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import combinations
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@ -104,9 +103,7 @@ class StatsDatasetSerializationTest(
@combinations.generate(test_base.default_test_combinations())
def test_set_stats_aggregator_not_support_checkpointing(self):
with self.assertRaisesRegexp(errors.UnimplementedError,
"does not support checkpointing"):
self.run_core_tests(self._build_dataset_stats_aggregator, 10)
self.run_core_tests(self._build_dataset_stats_aggregator, 10)
if __name__ == "__main__":