Adding checkpointing to StatsAggregatorDataset. We just forward it to the input dataset at the moment.
PiperOrigin-RevId: 284297931 Change-Id: I95aa8937183524328ffb5e82e29b8c9c89fc2b01
This commit is contained in:
parent
164aa8303b
commit
6593263ed7
@ -200,14 +200,14 @@ class SetStatsAggregatorDatasetOp : public UnaryDatasetOpKernel {
|
|||||||
}
|
}
|
||||||
|
|
||||||
Status SaveInternal(IteratorStateWriter* writer) override {
|
Status SaveInternal(IteratorStateWriter* writer) override {
|
||||||
return errors::Unimplemented(dataset()->DebugString(),
|
mutex_lock l(mu_);
|
||||||
" does not support checkpointing");
|
return SaveInput(writer, input_impl_);
|
||||||
}
|
}
|
||||||
|
|
||||||
Status RestoreInternal(IteratorContext* ctx,
|
Status RestoreInternal(IteratorContext* ctx,
|
||||||
IteratorStateReader* reader) override {
|
IteratorStateReader* reader) override {
|
||||||
return errors::Unimplemented(dataset()->DebugString(),
|
mutex_lock l(mu_);
|
||||||
" does not support checkpointing");
|
return RestoreInput(ctx, reader, input_impl_);
|
||||||
}
|
}
|
||||||
|
|
||||||
private:
|
private:
|
||||||
|
@ -25,7 +25,6 @@ from tensorflow.python.data.experimental.ops import stats_ops
|
|||||||
from tensorflow.python.data.kernel_tests import test_base
|
from tensorflow.python.data.kernel_tests import test_base
|
||||||
from tensorflow.python.data.ops import dataset_ops
|
from tensorflow.python.data.ops import dataset_ops
|
||||||
from tensorflow.python.framework import combinations
|
from tensorflow.python.framework import combinations
|
||||||
from tensorflow.python.framework import errors
|
|
||||||
from tensorflow.python.framework import ops
|
from tensorflow.python.framework import ops
|
||||||
from tensorflow.python.ops import array_ops
|
from tensorflow.python.ops import array_ops
|
||||||
from tensorflow.python.platform import test
|
from tensorflow.python.platform import test
|
||||||
@ -104,8 +103,6 @@ class StatsDatasetSerializationTest(
|
|||||||
|
|
||||||
@combinations.generate(test_base.default_test_combinations())
|
@combinations.generate(test_base.default_test_combinations())
|
||||||
def test_set_stats_aggregator_not_support_checkpointing(self):
|
def test_set_stats_aggregator_not_support_checkpointing(self):
|
||||||
with self.assertRaisesRegexp(errors.UnimplementedError,
|
|
||||||
"does not support checkpointing"):
|
|
||||||
self.run_core_tests(self._build_dataset_stats_aggregator, 10)
|
self.run_core_tests(self._build_dataset_stats_aggregator, 10)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user