Adding checkpointing to StatsAggregatorDataset. We just forward it to the input dataset at the moment.
PiperOrigin-RevId: 284297931 Change-Id: I95aa8937183524328ffb5e82e29b8c9c89fc2b01
This commit is contained in:
parent
164aa8303b
commit
6593263ed7
@ -200,14 +200,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
|
||||
}
|
||||
|
||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||
return errors::Unimplemented(dataset()->DebugString(),
|
||||
" does not support checkpointing");
|
||||
mutex_lock l(mu_);
|
||||
return SaveInput(writer, input_impl_);
|
||||
}
|
||||
|
||||
Status RestoreInternal(IteratorContext* ctx,
|
||||
IteratorStateReader* reader) override {
|
||||
return errors::Unimplemented(dataset()->DebugString(),
|
||||
" does not support checkpointing");
|
||||
mutex_lock l(mu_);
|
||||
return RestoreInput(ctx, reader, input_impl_);
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -25,7 +25,6 @@ from tensorflow.python.data.experimental.ops import stats_ops
|
||||
from tensorflow.python.data.kernel_tests import test_base
|
||||
from tensorflow.python.data.ops import dataset_ops
|
||||
from tensorflow.python.framework import combinations
|
||||
from tensorflow.python.framework import errors
|
||||
from tensorflow.python.framework import ops
|
||||
from tensorflow.python.ops import array_ops
|
||||
from tensorflow.python.platform import test
|
||||
@ -104,9 +103,7 @@ class StatsDatasetSerializationTest(
|
||||
|
||||
@combinations.generate(test_base.default_test_combinations())
|
||||
def test_set_stats_aggregator_not_support_checkpointing(self):
|
||||
with self.assertRaisesRegexp(errors.UnimplementedError,
|
||||
"does not support checkpointing"):
|
||||
self.run_core_tests(self._build_dataset_stats_aggregator, 10)
|
||||
self.run_core_tests(self._build_dataset_stats_aggregator, 10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
Loading…
x
Reference in New Issue
Block a user